#include "interspecies.h"
#include <gsl/gsl_rng.h>
#include <math.h>

/* Estimate the parameters ('chemical potentials') of the prior
 * distribution over the number of windows and colors. We take
 * as input the window structure and the desired number of windows
 * w and colors c. We then determine the chemical potentials at which the
 * peak of the distribution occurs at the desired number of
 * colors c and windows w. To do this we create random states containing
 * w-1 and w windows and count how many free windows there are in these states,
 * i.e. if in a state s with w-1 windows there are f free windows, and if after adding
 * another window there are f' free windows, then we record X = f * f'/(w(w+1))
 * and we average X over a large number of different states.
 * We also need to calculate Stirling numbers of the second kind S^w_c
 * (the number of ways of dividing w windows into c colors).
 */

/* Number of random window configurations that estimatepriorparams()
 * draws when averaging the free-window statistic X. */
int NUMSAMP = 1000;

/**log-sum rule for Stirling number calculation***/
double get_log_sum(double t1,double t2){
  if(t1 < -1 && t2 < -1){
    return t1+t2;
  }
  else if(t1 < -1){
    return t2;
  }
  else if(t2 < -1){
    return t1;
  }
  else{
    if(t1 > t2){
      return (t1 + log(1.0+exp(t2-t1)));
    }
    else{
      return (t2 + log(1.0+exp(t1-t2)));
    }
  }
}

/**
 * Get the lambda chemical potential for the number of colors.
 *
 * @param stirsec table indexed as stirsec[windows][colors]
 * @param win     row (number of windows) to read
 * @param col     column (number of colors) around which to difference;
 *                columns col-1 and col+1 of row win are read
 * @return half the difference stirsec[win][col+1] - stirsec[win][col-1]
 */
double get_lambda(double **stirsec,int win,int col){
  const double *row = stirsec[win];

  return 0.5 * (row[col+1] - row[col-1]);
}


/**
 * Put the logarithms of all Stirling numbers of the second kind up to
 * S^(win+1)_(col+1) into stirsec, indexed as stirsec[windows][colors].
 *
 * Column 1 is written for rows 0..win+1; columns 2..col+1 are written for
 * rows 1..win+1. Entries with fewer windows than colors receive the marker
 * -100.0. Column 0 is not touched. The caller must therefore supply at
 * least win+2 rows of at least col+2 entries each.
 *
 * @param stirsec table to fill
 * @param win     desired number of windows
 * @param col     desired number of colors
 */
void get_stirling_nums(double **stirsec,int win,int col) {

  const int lastwin = win+1;
  const int lastcol = col+1;
  int nwin,ncol;
  double logcol;

  /* a single colour: exactly one partition, so the log is zero */
  for(nwin=0;nwin<=lastwin;++nwin){
    stirsec[nwin][1] = 0.0;
  }

  for(ncol=2;ncol<=lastcol;++ncol){
    logcol = log(ncol);

    /* fewer windows than colours: no partition exists */
    for(nwin=1;nwin<ncol && nwin<=lastwin;++nwin){
      stirsec[nwin][ncol] = -100.0;
    }

    /* as many windows as colours: exactly one partition */
    if(ncol <= lastwin){
      stirsec[ncol][ncol] = 0.0;
    }

    /* recurrence S(w,c) = c*S(w-1,c) + S(w-1,c-1), carried out in log space */
    for(nwin=ncol+1;nwin<=lastwin;++nwin){
      stirsec[nwin][ncol] = get_log_sum(logcol+stirsec[nwin-1][ncol],stirsec[nwin-1][ncol-1]);
    }
  }

  return;
}



int estimatepriorparams(params *v) 
{
  int i,l,m,numfree,thisindex,samp,k;
  double thisx,meanx,varx,mu,lambda,tot;
  window *tempwina, *curwin, *tempwinb;
  window **freewins;
  double **stirsec;


  meanx = 0;
  varx = 0;
  thisx = 0;
  /**total number of windows**/
  l=(v->win)->len;
  /**array with pointers to all existing windows***/
  freewins = (window **) calloc(l,sizeof(window *));

  numfree = 0;
  for(i=0;i<l;++i)
    {
      curwin=&g_array_index(v->win,window,i);
      freewins[i] = curwin;
      /**temporarily use colour to refer to index***/
      curwin->colour = i;
      curwin->blocked = 0;
      ++numfree;
    }
  

  for(samp=0;samp<NUMSAMP;++samp)
    {
      /**fill in deswin windows***/
      for(k=0;k<(v->deswin);++k){
	/**pick a random window from all the free windows**/
	i=gsl_rng_uniform_int(v->gslrand, numfree);
	curwin = freewins[i];
	/**set not free **/
	curwin->blocked = 1;
	/**swap with last in list**/
	tempwina = freewins[numfree-1];
	freewins[numfree-1] = curwin;
	freewins[i] = tempwina;
	tempwina->colour = i;
	--numfree;
	/***go through all blockedwins***/
	for (m=0;m<(curwin->blockedwins->len);++m) 
	  {
	    tempwina = g_ptr_array_index(curwin->blockedwins,m);
	    /**check this window is free***/
	    if(tempwina->blocked == 0) 
	      {
		tempwina->blocked = 1;
		/**remove it from free list**/
		thisindex = tempwina->colour;
		if(thisindex >= numfree)
		  {
		    fprintf(stderr,"error thisindex %d is bigger than numfree %d\n",thisindex,numfree);
		  }
		tempwinb = freewins[numfree-1];
		freewins[numfree-1] = tempwina;
		freewins[thisindex] = tempwinb;
		tempwinb->colour = thisindex;
		--numfree;
	      }
	  }
	/**if running out of free windows prematurely, return error****/
	if(numfree <= 0 && k< (v->deswin-1))
	  {
	    return 1;
	  }
	if(k+1 == ((v->deswin)-1))
	  {
	    thisx = ((double) numfree)/((double) (v->deswin));
	  }
	if(k+1 == (v->deswin))
	  {
	    thisx *= ((double) numfree)/((double) ((v->deswin)+1));
	  }
	
      }
      meanx += thisx;
      varx += thisx*thisx;

      /**go clean all the windows***/
      numfree = 0;
      for(i=0;i<l;++i)
	{
	  curwin=&g_array_index(v->win,window,i);
	  freewins[i] = curwin;
	  /**temporarily use colour to refer to index***/
	  curwin->colour = i;
	  curwin->blocked = 0;
	  ++numfree;
	}
    }
  meanx = meanx/((double) NUMSAMP);
  varx = varx/((double) NUMSAMP);
  varx -= meanx*meanx;
  varx = sqrt(varx/((double) NUMSAMP));
  
  /**final cleanup***/
  for(i=0;i<l;++i)
    {
      curwin=&g_array_index(v->win,window,i);
      /**temporarily use colour to refer to index***/
      curwin->colour = 0;
      curwin->blocked = 0;
    }
  free(freewins);

  /**now go calculate the Stirling numbers of the second kind*****/
   /**memory allocation***/
  stirsec = (double **) calloc((v->deswin)+2,sizeof(double *));
  for(i=0;i<=((v->deswin)+1);++i){
    stirsec[i] = (double *) calloc((v->descol)+2,sizeof(double));
  }
  get_stirling_nums(stirsec,(v->deswin),(v->descol));
  mu = 0.5 * log(meanx) + 0.5*(stirsec[(v->deswin)+1][(v->descol)]-stirsec[(v->deswin)-1][(v->descol)]);
  lambda = get_lambda(stirsec,(v->deswin),(v->descol));

  /***if we are using a WM file then we need to correct also for total TF num count ****/
  if((v->priorbinbase)->len > 0)
    {
      tot = (double) (v->numtfs+(v->priorbinbase)->len - v->descol);
      if(tot > 0)
        {
          lambda += 0.5 * log(tot * (tot+1.0));
        }
      /**if the desired number of colors is also the maximum allowed then we simply turn off the chemical potential***/
      else
        {
          lambda = 0.0;
        }
    }

  /**use += when correcting for the score of each window**/
  /*use = when just correcting for state space counts***/
  v->mu = mu;
  v->lambda = lambda;


  /***free stirling numbers of the second kind****/
  for(i=0;i<=((v->deswin)+1);++i){
    free(stirsec[i]);
  }
  free(stirsec);

  return 0;

}
