/*****************************************************************
        Copyright by Rockefeller University,
can not be reproduced or distributed without written permission of
copyright holder.  Version of October 2003.

Written by Saurabh Sinha (contact person), Erik van Nimwegen, and 
Eric Siggia.

The program stubb (and its relatives) implement an algorithm for
finding likely cis-regulatory modules, described in the following
paper:
"A Probabilistic Method to Detect Regulatory Modules"
by Saurabh Sinha, Erik van Nimwegen and Eric Siggia. 
Eleventh International Conference on Intelligent Systems for
Molecular Biology, Brisbane, Australia, July 2003, pg 292-301.

The file sample/gap_wtmx that comes with this distribution includes 
a sample set of transcription factor weight matrices (PWM's) that 
were reported in :
"Computational detection of genomic cis-regulatory modules applied
to body patterning in the early Drosophila embryo"
by N. Rajewsky, M. Vergassola, U. Gaul and E. Siggia.
BMC Bioinformatics 3 (30) 2002.
******************************************************************/

#include "parameters.h"
#include <math.h>
#include "mynewmat.h"
#include "util.h"

// Default constructor: the object starts out owning no arrays. All three
// array pointers are cleared so that Destroy() can tell nothing was allocated.
Parameters_H01::Parameters_H01()
{
  _pi = _oldpi = NULL;
  _Cij = NULL;
}

// Destructor: releases the arrays owned by this class through Destroy().
Parameters_H01::~Parameters_H01()
{
  this->Destroy();
}

// Copy constructor: the Parameters_H1 part is copy-constructed first, after
// which Copy() allocates fresh arrays and duplicates p's contents into them.
Parameters_H01::Parameters_H01(const Parameters_H01 &p)
  : Parameters_H1(p)
{
  this->Copy(p);
}

// Assignment operator (self-assignment is a no-op).
//
// The arrays owned by this object are released BEFORE the base-class part is
// assigned. Destroy() walks the rows of _Cij using _numWM, while Copy() sizes
// the new arrays from _numWM after the base assignment, i.e. the base
// assignment is expected to bring _numWM in line with p (presumably it copies
// it -- the base class is not visible from this file). Releasing afterwards
// would therefore free the old rows using the new count, leaking rows or
// deleting past the end whenever the two models differ in size.
//
// The pointers are cleared right after the release so that nothing invoked by
// the base assignment can see (or free again) the stale arrays.
//
// Returns a reference to this object.
Parameters_H01& Parameters_H01::operator=(const Parameters_H01 &p)
{
  if (this != &p) {
    Destroy();
    _pi = NULL;
    _oldpi = NULL;
    _Cij = NULL;

    this->Parameters_H1::operator=(p);
    Copy(p);
  }
  return *this;
}

// Duplicates the arrays owned by this class from p into newly allocated
// storage. Array dimensions come from this object's _numWM, so the base-class
// part must already match p when this is called. A pointer that is NULL in p
// becomes NULL here. Arrays held before the call are not freed by this routine.
void Parameters_H01::Copy(const Parameters_H01 &p)
{
  // _pi
  const bool havePi = (p._pi != NULL);
  _pi = havePi ? new float[_numWM] : NULL;
  for (int k=0; havePi && k<_numWM; k++) _pi[k] = p._pi[k];

  // _oldpi
  const bool haveOldPi = (p._oldpi != NULL);
  _oldpi = haveOldPi ? new float[_numWM] : NULL;
  for (int k=0; haveOldPi && k<_numWM; k++) _oldpi[k] = p._oldpi[k];

  // _Cij (square _numWM x _numWM matrix of flags)
  if (!p._Cij) {
    _Cij = NULL;
    return;
  }
  _Cij = new bool *[_numWM];
  for (int row=0; row<_numWM; row++) {
    _Cij[row] = new bool[_numWM];
    int col = 0;
    while (col < _numWM) {
      _Cij[row][col] = p._Cij[row][col];
      col++;
    }
  }
}

// Releases the arrays owned by this class (_pi, _oldpi and the _numWM rows of
// _Cij) and resets the pointers to NULL.
//
// Clearing the pointers makes the routine safe to call more than once: a
// second call (for instance from the destructor after an assignment that
// failed part-way) finds nothing to free instead of deleting the same memory
// again. _numWM must still describe the dimensions the arrays were allocated
// with when this is called.
void Parameters_H01::Destroy()
{
  delete [] _pi;       // delete[] on NULL is a no-op
  _pi = NULL;

  delete [] _oldpi;
  _oldpi = NULL;

  if (_Cij) {
    for (int i=0; i<_numWM; i++) delete [] _Cij[i];
    delete [] _Cij;
    _Cij = NULL;
  }
}

// Initializes this mixed model from an already available H0 model.
//
//   wl   - windows handed on to Initialize()
//   wmc  - weight matrix collection handed on to Initialize()
//   init - H0 parameters whose _pi values seed this model
//
// Steps:
//   1. Build a temporary Parameters_H1 in which every row of _pij is init's
//      _pi vector, and run Initialize() with it as the seed.
//   2. Copy init's _pi into both _pi and _oldpi.
//   3. Allocate _Cij (numWM x numWM) with every entry false, i.e. no pair is
//      marked as correlated yet.
//
// Arrays left over from an earlier initialization are released first;
// previously they were simply overwritten and leaked when this was called on
// an object that already owned them.
void Parameters_H01::UpgradeInitialize(vector<Window *> *wl, WtMxCollection *wmc, Parameters_H0 *init)
{
  // Free what we own while _numWM still describes the old arrays.
  Destroy();
  _pi = NULL;
  _oldpi = NULL;
  _Cij = NULL;

  Parameters_H1 seed(wmc,0);  // 0 is for bias towards first motif. Doesn't matter, since it'll be reset
  int numWM = init->NumWM();
  for (int i=0; i<numWM; i++) {
    for (int j=0; j<numWM; j++) {
      seed._pij[i][j] = init->_pi[j];
    }
  }
  Initialize(wl,wmc,&seed);

  _pi = new float[numWM];
  _oldpi = new float[numWM];
  for (int i=0; i<numWM; i++) {
    _pi[i] = init->_pi[i];
    _oldpi[i] = _pi[i];
  }

  _Cij = new bool*[numWM];
  for (int i=0; i<numWM; i++) {
    _Cij[i] = new bool[numWM];
    for (int j=0; j<numWM; j++) {
      _Cij[i][j] = false;
    }
  }
}
				       
void Parameters_H01::EnableCorrelation(int *flatlist, int size)
{
  int pos = 0;
  for (int i=0; i<size; i++) {
    int l = flatlist[pos++];
    int r = flatlist[pos++];
    _Cij[l][r] = true;
  }

  // also adjust the _pij's accordingly, using _pij (j \in C_i), _pj (j not in C_i)
  AdjustProbabilityParameters();
}

void Parameters_H01::Train(bool differential)
{
  this->Parameters_H1::Train(differential);
}

void Parameters_H01::Train(char *corr_file, bool differential)
{
  if (!_is_initialized) {
    printf("Training attempted without initialization\n");
    _is_trained = false;
    return;
  }

  // Read corr_file and update 

  FILE *fp = fopen(corr_file, "r");
  char line[1024];
  bool seen_corr = false;
  while (!seen_corr && fgets(line,1023,fp)) {
    if (!strstr(line,"Correlated probabilities")) continue;
    else {
      while (fgets(line,1023,fp)) {
	float pr;
	if (strstr(line,"Uncorrelated probabilities")) {
	  int i = 0;
	  char *ptr = strstr(line,":");
	  while ((ptr = strstr(ptr," ")) != NULL) {
	    sscanf(ptr,"%f",&pr);
	    _pi[i++] = pr;
	    ptr++;
	    if (i == _numWM) break;
	  }
	  if (i != _numWM) {
	    printf("Error: Corr file not complete\n");
	    exit(1);
	  }
	  continue;
	}
	if (line[0]=='<') {
	  seen_corr = true;
	  break;
	}

	int left,right;
	if (sscanf(line,"%d -> %d: %f",&left,&right,&pr) < 3) break;
	_pij[left][right] = pr;
	_Cij[left][right] = true;
      }
    }
  }
  fclose(fp);

  // also adjust the _pij's accordingly, using _pij (j \in C_i), _pj (j not in C_i)
  AdjustProbabilityParameters();

  // Update the internal variables
  DoDynamicProgramming();

  // Compute Free Energy
  _free_energy = EvaluateFreeEnergy();  

  _num_iterations = 1;        // record how many iterations were needed

  if (differential) {
    _free_energy_differential = EvaluateFreeEnergyBackground() - _free_energy;
    if (_free_energy_differential < 0) _free_energy_differential = 0;
  }
  else _free_energy_differential = 0;

  int numSpecific = 0;
  int numTotal = 0;
  for (int i=0; i<_windows->size(); i++) {
    Window *win = (*_windows)[i];
    numSpecific += win->NumSpecificCharacters();
    numTotal += win->Length();
  }

#ifdef _NORMALIZE_SPACERS
  if (numSpecific == 0) _free_energy_differential = 0;
  else _free_energy_differential *= float(numTotal)/float(numSpecific);
#endif

  _free_energy_perlength = _free_energy_differential / float(numTotal);

  _is_trained = true;
}

// Re-derives the transition probabilities of the pairs NOT flagged in _Cij.
// For each row i: the probability mass already claimed by flagged entries of
// _pij[i] is summed, and the remainder (1 - that mass) is spread over the
// unflagged columns in proportion to their _pi values. Flagged entries are
// left as they are.
// NOTE(review): when a row has unflagged columns whose _pi values add up to
// zero, the division below is 0/0 -- confirm callers cannot reach that state.
void Parameters_H01::AdjustProbabilityParameters()
{
  for (int from=0; from<_numWM; from++) {
    float corrMass = 0, uncorrMass = 0;
    for (int to=0; to<_numWM; to++) {
      if (!_Cij[from][to]) uncorrMass += _pi[to];
      else corrMass += _pij[from][to];
    }

    for (int to=0; to<_numWM; to++) {
      if (_Cij[from][to]) continue;
      _pij[from][to] = (_pi[to]*(1-corrMass)/uncorrMass);
    }
  }
}

// Writes the model probabilities to fp:
//   - a "Correlated probabilities:" header followed by one "i -> j: p" line
//     for every pair flagged in _Cij
//   - one "Uncorrelated probabilities: " line listing all _pi values
// The verbose argument is not used by this implementation.
void Parameters_H01::PrintProbabilities(FILE *fp, bool verbose)
{
  fputs("Correlated probabilities:\n",fp);
  for (int from=0; from<_numWM; from++) {
    for (int to=0; to<_numWM; to++) {
      if (!_Cij[from][to]) continue;
      fprintf(fp,"%d -> %d: %.4f\n",from,to,_pij[from][to]);
    }
  }

  fputs("Uncorrelated probabilities: ",fp);
  int k = 0;
  while (k < _numWM) {
    fprintf(fp,"%.4f ",_pi[k]);
    k++;
  }
  fputs("\n",fp);
}

void Parameters_H01::Revert()
{
  this->Parameters_H1::Revert();
  for (int i=0; i<_numWM; i++) {
    _pi[i] = _oldpi[i];
  }
}

void Parameters_H01::UpdateTransitionProbabilities()
{
  int i;

  /* First create the average counts ... the fringe correction may matter */
  DTYPE  **Aij = new DTYPE  *[_numWM];
  for (i=0; i<_numWM; i++) {
    Aij[i] = new DTYPE [_numWM];
    for (int j=0; j<_numWM; j++) {
      Aij[i][j] = 0;
      int numWindows = _windows->size();
      for (int wi=0; wi<numWindows; wi++) {
	DTYPE  sum = _Aij[wi][i][j];
	Aij[i][j] += sum*FringeCorrectionFactor(wi);;
      }      
    }
  }

  /* also create some variables to be used in both kinds of updates */
  int **C_i = new int *[_numWM]; 
  int **C_i_bar = new int *[_numWM];
  int *sizeCi = new int [_numWM];
  int *sizeCibar = new int [_numWM];
  int *totalSize = new int [_numWM];
  DTYPE  *sumAijbar = new DTYPE [_numWM];
  for (i=0; i<_numWM; i++) {
    C_i[i] = new int[_numWM];
    C_i_bar[i] = new int[_numWM];
    sizeCi[i] = 0; sizeCibar[i] = 0; totalSize[i] = 0;
    sumAijbar[i] = 0;                   // this will store the \sum_{k \in C_i^{bar}} A_{ik}

    for (int j=0; j<_numWM; j++) {
      if (_Cij[i][j] && Aij[i][j] < SMALL_FLOAT) {    // if this is zero, this is not to be considered part of C_i
	_pij[i][j] = 0;
	continue;
      }
                                        // now, update the sets C_i and C_i_bar
      if (_Cij[i][j]) C_i[i][sizeCi[i]++] = j;
      else {
	C_i_bar[i][sizeCibar[i]++] = j;
	sumAijbar[i] += Aij[i][j];
      }
      totalSize[i]++;
    }
  }

  /* update the correlated probabilities using only the Aij's */
  for (i=0; i<_numWM; i++) {
    if (sizeCi[i]==0) continue;            // no pij parameters to be trained for this i
    if (sizeCi[i]==totalSize[i] || sumAijbar[i] < SMALL_FLOAT) {    // Ci_bar = NULL or immaterial
      DTYPE  sumAij = 0;
      for (int j=0; j<sizeCi[i]; j++) sumAij += Aij[i][C_i[i][j]];
      for (int j=0; j<sizeCi[i]; j++) _pij[i][C_i[i][j]] = Aij[i][C_i[i][j]]/sumAij;
      continue;
    }
    // in between case: at least one but not all j's are correlated
    if (sizeCi[i]==1) {                    // no need to set up system of equations, it is only one equation
      _pij[i][C_i[i][0]] = 1/(1+sumAijbar[i]/Aij[i][C_i[i][0]]);
      continue;
    }
    else {                              // set up a system of sizeCi equations
      int n = sizeCi[i];
      DTYPE  sumAik = 0;                // this has to be positive, otherwise matrix M will be singular !
                                        // it is ensured to be positive by a check above
     	
      Matrix m(n,n);
      for (int j=0; j<n; j++) {
	for (int k=0; k<n; k++) {
	  DTYPE  val = 1;
	  if (j==k) val += sumAijbar[i]/Aij[i][C_i[i][j]];
	  m(j+1,k+1) = val;
	}
      }

      ColumnVector b(n);
      for (int j=0; j<n; j++) b(j+1) = 1.0;

      ColumnVector x = m.i() * b;
        
#ifdef _DEBUG
      printf("Computed x = \n");
      for (int j=0; j<n; j++) printf("%g\n",x(j+1));
#endif
      for (int j=0; j<n; j++) _pij[i][C_i[i][j]] = x(j+1);
    }
  }

  /* then update the uncorrelated probabilities using only the Aij's (and not the newly computed pij's */
  int n = _numWM;
  // first compute the u's
#ifdef _DEBUG
  printf("Aij =\n");
  for (i=0; i<n; i++) {
    for (int j=0; j<n; j++) 
      printf("%.3f ",Aij[i][j]);
    printf("\n");
  }
  printf("u=\n");
#endif
  DDTYPE  *u = new DDTYPE [n];
  for (i=0; i<n; i++) {
    u[i] = log(_pi[i]);
#ifdef _DEBUG
    printf("%g ",u[i]);
#endif
  }
#ifdef _DEBUG
  printf("\n");
#endif

  // then Newton's method to find the best update of u
  DDTYPE  previous_least_squares = INF_LEAST_SQUARES;
  DDTYPE  current_least_squares;
  int count = 0;
  bool newton_failed = false;

  // Allocate the temporary matrices needed by the loop below
  ColumnVector gradFu(n);
  SymmetricMatrix hessianFu(n);
  Matrix V(n,n);
  DiagonalMatrix Lambda(n);
  SymmetricMatrix WORK(n);

  do {
#ifdef _DEBUG
    printf("scaled u=\n");
    for (i=0; i<n; i++) printf("%g ",u[i]);
    printf("\n");
#endif
    // Compute grad F
    for (i=0; i<n; i++) gradFu(i+1) = GradF(u,i,C_i_bar,sizeCibar,Aij,n);
#ifdef _DEBUG
    printf("grad=\n");
    for (i=0; i<n; i++) printf("%g\n",gradFu(i+1));
#endif
    // Evaluate current least squares and check for convergence
    current_least_squares = 0;
    for (i=0; i<n; i++) current_least_squares += pow(gradFu(i+1),2);
    current_least_squares = sqrt(current_least_squares);

#ifdef _DEBUG 
    printf("grad B norm (2) = %g\n",current_least_squares);
#endif
    if (current_least_squares < LEAST_SQUARES_THRESHOLD) {
#ifdef _DEBUG
      printf("Newton Converged after %d iterations\n",count);
#endif
      break;
    }
    previous_least_squares = current_least_squares;

    // Compute Hessian
    for (i=0; i<n; i++) {
      for (int j=i; j<n; j++) {
	DDTYPE  val = HessianF(u,i,j,C_i_bar,sizeCibar,Aij,n);
	hessianFu(j+1,i+1) = val;
      }
    }
#ifdef _DEBUG
    printf("hessian=\n");
    for (i=0; i<n; i++) {
      for (int j=0; j<n; j++) {
	if (i<j) printf("%g ",hessianFu(j+1,i+1));
	else printf("%g ",hessianFu(i+1,j+1));
      }
      printf("\n");
    }
#endif

    // Do SVD of Hessian to get \lambda and V
    Jacobi(hessianFu,Lambda,WORK,V);
#ifdef _DEBUG
    printf("V=\n");
    for (i=0; i<n; i++) {
      for (int j=0; j<n; j++) {
	printf("%g ",V(i+1,j+1));
      }
      printf("\n");
    }
    printf("Lambda=\n");
    for (i=0; i<n; i++) printf("%g\n",Lambda(i+1));
#endif

    DDTYPE  max_abs_lambda = -1; 
    for (i=0; i<n; i++) {
      if (abs(Lambda(i+1)) > max_abs_lambda) max_abs_lambda = abs(Lambda(i+1));
    }
    if (max_abs_lambda < LAMBDA_THRESHOLD) {
      newton_failed = true;
      break;
    }
    for (i=0; i<n; i++) {
      if (abs(Lambda(i+1)) < LAMBDA_THRESHOLD) Lambda(i+1) = 0;
    }

    // Solve hessianFu*(change_u) = -gradFu
    ColumnVector change_u(n);
    gradFu = -gradFu;
    Matrix hessianFuinverse(n,n); hessianFuinverse = 0.0;
    for (i=0; i<n; i++) {
      if (abs(Lambda(i+1)) < LAMBDA_THRESHOLD) continue;
      hessianFuinverse += V.Column(i+1)*V.Column(i+1).t()*(1/Lambda(i+1));
    }
    change_u = hessianFuinverse * gradFu;
#ifdef _DEBUG
    printf("change_u=\n");
    for (i=0; i<n; i++) printf("%g\n",change_u(i+1));
#endif
    // update u and loop back ... also, scale the u's additively
#ifdef _DEBUG
    printf("new u=\n");
#endif
    DDTYPE  max_u = -10000;
    for (i=0; i<n; i++) { 
      u[i] += change_u(i+1);
      if (u[i] > max_u) max_u = u[i];
#ifdef _DEBUG
      printf("%g ",u[i]);
#endif
    }
#ifdef _DEBUG
    printf("\n");
#endif
    if (abs(max_u) > SMALL_FLOAT) for (i=0; i<n; i++) u[i] -= max_u;

    count++;
  } while (count < MAX_NEWTON_ITERATIONS); 

  // Check if there was convergence or not; exit if not
  if (newton_failed || count == MAX_NEWTON_ITERATIONS) {
#ifdef _DEBUG
    printf("Newton Failed To Converge\n");
#endif
#ifdef WARNINGS
    char warning[10000];
    char seqname[1024];
    (*_windows)[0]->Seq()->Name(seqname);
    sprintf(warning,"Warning: Window %d (Sequence %s): Newton's method couldnt converge (iterations=%d)\n",(*_windows)[0]->Start(),seqname,count);
    Warn(warning);
#endif
  }
  else {
    // update the probabilities and normalize
    DTYPE  sumpi = 0;
    for (i=0; i<n; i++) {
      _pi[i] = exp(u[i]);
      sumpi += _pi[i];
    }
    for (i=0; i<n; i++) _pi[i] /= sumpi;
  }

  delete [] u;

  /* also adjust the _pij's (j not in C_i) accordingly, using _pij (j \in C_i), _pj (j not in C_i) */
  AdjustProbabilityParameters();

  /* Clean up of course */
  delete [] totalSize;
  delete [] sizeCi;
  delete [] sizeCibar;
  delete [] sumAijbar;
  for (i=0; i<_numWM; i++) {
    delete [] C_i[i];
    delete [] C_i_bar[i];
  }
  delete [] C_i;
  delete [] C_i_bar;
  for (i=0; i<_numWM; i++) delete [] Aij[i];
  delete [] Aij;
}

void Parameters_H01::Update()
{
#ifdef _DEBUG
  printf("Updating for window %d\n",(*_windows)[0]->Start());
#endif
  // UpdateEmissionProbabilities(); 
  UpdateTransitionProbabilities();
#ifdef _DEBUG
  printf("Done updating for window %d\n",(*_windows)[0]->Start());
#endif
}


// Not supported for the mixed model: the arguments are ignored and nothing is
// modified. With WARNINGS defined, a warning is emitted through Warn().
void Parameters_H01::ForceCorrelation(int i, int j, float strength)
{
  (void)i;
  (void)j;
  (void)strength;
#ifdef WARNINGS
  Warn("Warning: ForceCorrelation called on mixed model, cannot enforce\n");
#endif
}

// Sequence generation is not implemented for the mixed model: an error message
// is printed and the program exits with status 1. All arguments are ignored
// and the function never returns to its caller.
char *Parameters_H01::CreateSequence(int length, int seed, bool verbose)
{
  printf("Error: CreateSequence not yet supported for mixed model\n");
  exit(1);
  return NULL;   // never reached; gives the non-void function a return on every path
}

// Component l of the gradient used by the Newton iteration on u.
//
//   u         - current parameter vector
//   l         - index of the component to evaluate
//   C_i_bar   - for each row i, the list of column indices in its set
//   sizeCibar - number of valid entries in each C_i_bar[i]
//   Aij       - count matrix
//   N         - number of rows
//
// Only rows whose set contains l contribute. For such a row i the term is
//   Aij[i][l] - S_i * exp(u[l]) / E_i
// where E_i is the sum of exp(u[k]) over the set and S_i the sum of Aij[i][k]
// over the set, counts below SMALL_FLOAT being treated as 0. When E_i is
// below SMALL_DOUBLE the term reduces to Aij[i][l].
DDTYPE  Parameters_H01::GradF(DDTYPE  *u, int l, int **C_i_bar, int *sizeCibar, DTYPE  **Aij, int N)
{
  DDTYPE  total = 0;
  for (int row=0; row<N; row++) {
    const int *members = C_i_bar[row];
    const int count = sizeCibar[row];

    // does this row's set contain l at all?
    int pos = 0;
    while (pos < count && members[pos] != l) pos++;
    if (pos == count) continue;

    DDTYPE  expSum = 0, countSum = 0;
    for (int k=0; k<count; k++) {
      const int col = members[k];
      expSum += exp(u[col]);
      DDTYPE a = (DDTYPE)Aij[row][col];
      if (a < SMALL_FLOAT) a = 0;
      countSum += a;
    }

    if (expSum < SMALL_DOUBLE) {
      total += Aij[row][l];
    }
    else {
      total += (Aij[row][l] - countSum*(exp(u[l])/expSum));
    }
  }
  return total;
}

// Entry (l, m) of the Hessian used by the Newton iteration on u.
//
//   u         - current parameter vector
//   l, m      - indices of the entry to evaluate
//   C_i_bar   - for each row i, the list of column indices in its set
//   sizeCibar - number of valid entries in each C_i_bar[i]
//   Aij       - count matrix
//   N         - number of rows
//
// Only rows whose set contains both l and m contribute, and rows whose
// exponential sum E_i is below SMALL_DOUBLE are skipped. The contribution is
//   S_i * (exp(u[l]) / E_i) * (exp(u[m]) / E_i - [l == m])
// with E_i and S_i as sums of exp(u[k]) and of Aij[i][k] over the set, counts
// below SMALL_FLOAT being treated as 0.
DDTYPE  Parameters_H01::HessianF(DDTYPE  *u, int l, int m, int **C_i_bar, int *sizeCibar, DTYPE  **Aij, int N)
{
  DDTYPE  total = 0;
  for (int row=0; row<N; row++) {
    const int *members = C_i_bar[row];
    const int count = sizeCibar[row];

    bool hasL = false, hasM = false;
    for (int k=0; k<count; k++) {
      hasL = hasL || (members[k]==l);
      hasM = hasM || (members[k]==m);
    }
    if (!(hasL && hasM)) continue;

    DDTYPE  expSum = 0, countSum = 0;
    for (int k=0; k<count; k++) {
      const int col = members[k];
      expSum += exp(u[col]);
      DDTYPE a = (DDTYPE)Aij[row][col];
      if (a < SMALL_FLOAT) a = 0;
      countSum += a;
    }
    if (expSum < SMALL_DOUBLE) continue;

    total += countSum*(exp(u[l])/expSum)*((exp(u[m])/expSum) - (l==m?1:0));
  }
  return total;
}


