/*****************************************************************
        Copyright by Rockefeller University,
can not be reproduced or distributed without written permission of
copyright holder.  Version of October 2003.

Written by Saurabh Sinha (contact person), Erik van Nimwegen, and 
Eric Siggia.

The program stubb (and its relatives) implement an algorithm for
finding likely cis-regulatory modules, described in the following
paper:
"A Probabilistic Method to Detect Regulatory Modules"
by Saurabh Sinha, Erik van Nimwegen and Eric Siggia. 
Eleventh International Conference on Intelligent Systems for
Molecular Biology, Brisbane, Australia, July 2003, pg 292-301.

The file sample/gap_wtmx that comes with this distribution includes 
a sample set of transcription factor weight matrices (PWM's) that 
were reported in :
"Computational detection of genomic cis-regulatory modules applied
to body patterning in the early Drosophila embryo"
by N. Rajewsky, M. Vergassola, U. Gaul and E. Siggia.
BMC Bioinformatics 3 (30) 2002.
******************************************************************/

#include "wtmx.h"
#include <math.h>

// Builds an order-0 weight matrix from a 4-row table.
//   w   : w[row][col] for row 0..3 and col 0..len-1; the values are copied.
//   len : number of columns.
//   nm  : name to store; when NULL a name "wtmx_<globalid>" is generated
//         and globalid is incremented.
//   psd : pseudo count used by Normalize().
// The copied table is normalized before the constructor returns.
WtMx::WtMx(float **w, int len, char *nm, float psd)
{
  // Scalar state first: Normalize() below reads _length, _powMorder,
  // _pseudo_count, _basicwtmx and _is_normalized.
  _length = len;
  _pseudo_count = psd;
  _Morder = 0;
  _powMorder = 1;
  _basicwtmx = NULL;
  _higherwtmx = NULL;
  _udata = -1;
  _forward_bias = 0.5;
  _is_special = false;

  // Deep copy of the caller's table into the single history slot.
  _wtmx = new float **[1];
  float **rows = new float *[4];
  _wtmx[0] = rows;
  for (int row=0; row<4; row++) {
    rows[row] = new float[len];
    for (int col=0; col<len; col++) {
      rows[row][col] = w[row][col];
    }
  }

  if (nm == NULL) {
    sprintf(_name,"wtmx_%d",globalid++);
  }
  else {
    strcpy(_name,nm);
  }

  _is_normalized = false;
  Normalize();
}

// Builds a weight matrix with Markov order Morder.
//   w      : w[p][row][col] for history p in 0..4^Morder-1, row 0..3,
//            col 0..len-1; the values are copied.
//   bw     : order-0 table bw[row][col]; only read (and copied) when
//            Morder != 0.
//   len    : number of columns; must be <= 1 when Morder != 0, otherwise
//            the program prints an error and exits.
//   nm     : name to store; when NULL a name "wtmx_<globalid>" is generated.
//   Morder : Markov order.
//   psd    : pseudo count used by Normalize() for the history tables.
// All copied tables are normalized before the constructor returns.
WtMx::WtMx(float ***w, float **bw, int len, char *nm, int Morder, float psd)
{
  _length = len;
  _Morder = Morder;

  // 4^_Morder in exact integer arithmetic.  The former (int)pow(4,_Morder)
  // truncated a floating-point result, which can land one below the true
  // power when pow() is not exact.  A negative order still yields 0, as the
  // truncated pow() did.
  _powMorder = (_Morder < 0) ? 0 : 1;
  for (int k=0; k<_Morder; k++) {
    _powMorder *= 4;
  }

  if (nm != NULL) {
    strcpy(_name,nm);
  }
  else {
    sprintf(_name,"wtmx_%d",globalid++);
  }
  _pseudo_count = psd;

  // Deep copy of one 4 x _length table per history word.
  _wtmx = new float **[_powMorder];
  for (int p=0; p<_powMorder; p++) {
    _wtmx[p] = new float *[4];
    for (int i=0; i<4; i++) {
      _wtmx[p][i] = new float[_length];
      for (int j=0; j<_length; j++) {
        _wtmx[p][i][j] = w[p][i][j];
      }
    }
  }

  if (_Morder==0) {
    _basicwtmx = NULL;
    _higherwtmx = NULL;
    _is_normalized = false;
    Normalize();
  }
  else {
    // Reject unsupported input before allocating anything further.
    if (_length > 1) {
      printf("Non unit-length wtmx with higher order not supported\n");
      exit(1);
    }

    _basicwtmx = new float *[4];
    for (int i=0; i<4; i++) {
      _basicwtmx[i] = new float[_length];
      for (int j=0; j<_length; j++) {
        _basicwtmx[i][j] = bw[i][j];
      }
    }
    _is_normalized = false;
    Normalize();

    // _higherwtmx[i] is the product of the normalized _basicwtmx[d][0] over
    // the _Morder base-4 digits d of i.
    _higherwtmx = new float[_powMorder];
    for (int i=0; i<_powMorder; i++) {
      float term = 1; int a = i;
      for (int j=0; j<_Morder; j++) {
        term *= _basicwtmx[a%4][0];
        a /= 4;
      }
      _higherwtmx[i] = term;
    }
  }

  _udata = -1;
  _forward_bias = 0.5;
  _is_special = false;
}

// Builds the "special" single-column matrix for the character ch.
// Only 'N' is accepted; any other character prints an error and exits.
// A special matrix owns no tables (all table pointers are NULL).
WtMx::WtMx(char ch)
{
  if (ch != 'N') {
    printf("No special characters other than N supported yet\n");
    exit(1);
  }

  _is_special = true;
  _special_char = 'N';
  strcpy(_name,"_special_wtmx_N");

  // No frequency tables are stored for a special matrix.
  _wtmx = NULL;
  _basicwtmx = NULL;
  _higherwtmx = NULL;

  _length = 1;
  _Morder = 0;
  _powMorder = 1;
  _is_normalized = true;

  _pseudo_count = PSEUDO_COUNT;
  _forward_bias = 0.5;
  _udata = -1;
}

// Builds a deep copy of *w.
// For a special matrix only _special_char and _pseudo_count are taken from
// w; the name, user data and forward bias are reset to the same values the
// WtMx(char) constructor uses.  For a regular matrix every table and scalar
// field is copied.  A source with Markov order != 0 and length > 1 prints an
// error and exits.
WtMx::WtMx(WtMx *w)
{
  if (w->_is_special) {
    _is_special = true;
    _special_char = w->_special_char;
    _pseudo_count = w->_pseudo_count;
    strcpy(_name,"_special_wtmx_N");

    _wtmx = NULL;
    _basicwtmx = NULL;
    _higherwtmx = NULL;

    _length = 1;
    _Morder = 0;
    _powMorder = 1;
    _is_normalized = true;
    _udata = -1;
    _forward_bias = 0.5;
    return;
  }

  _Morder = w->_Morder;
  _powMorder = w->_powMorder;
  _length = w->_length;

  // One 4 x _length table per history word.
  _wtmx = new float **[_powMorder];
  for (int hist=0; hist<_powMorder; hist++) {
    float **dst = new float *[4];
    _wtmx[hist] = dst;
    for (int row=0; row<4; row++) {
      dst[row] = new float[_length];
      for (int col=0; col<_length; col++) {
        dst[row][col] = w->_wtmx[hist][row][col];
      }
    }
  }

  if (_Morder!=0) {
    _basicwtmx = new float *[4];
    for (int row=0; row<4; row++) {
      _basicwtmx[row] = new float[_length];
      for (int col=0; col<_length; col++) {
        _basicwtmx[row][col] = w->_basicwtmx[row][col];
      }
    }
    if (_length > 1) {
      printf("Non unit-length wtmx with higher order not supported\n");
      exit(1);
    }
    _higherwtmx =  new float[_powMorder];
    for (int hist=0; hist<_powMorder; hist++) {
      _higherwtmx[hist] = w->_higherwtmx[hist];
    }
  }
  else {
    _basicwtmx = NULL;
    _higherwtmx = NULL;
  }

  strcpy(_name,w->_name);
  _is_normalized = w->_is_normalized;
  _is_special = false;

  _udata = w->_udata;
  _pseudo_count = w->_pseudo_count;
  _forward_bias = w->_forward_bias;
}

// Releases every table this matrix allocated: the per-history tables in
// _wtmx, the order-0 table _basicwtmx and the _higherwtmx array.
WtMx::~WtMx()
{
  if (_wtmx != NULL) {
    for (int hist=0; hist<_powMorder; hist++) {
      float **rows = _wtmx[hist];
      if (rows == NULL) continue;
      for (int row=0; row<4; row++) {
        delete [] rows[row];   // delete [] on a null pointer does nothing
      }
      delete [] rows;
    }
    delete [] _wtmx;
  }

  if (_basicwtmx != NULL) {
    for (int row=0; row<4; row++) {
      delete [] _basicwtmx[row];
    }
    delete [] _basicwtmx;
  }

  delete [] _higherwtmx;
}

// Absolute value of x.  The argument and the whole expansion are
// parenthesized so that expression arguments such as abs(a-b) expand
// correctly.  Note that x is still evaluated more than once.
#define abs(x) ((x)>0?(x):-(x))

// Rescales every column of every stored table so that each entry becomes
// (entry + pseudo count) / (column sum + 4 * pseudo count).
// The per-history tables use _pseudo_count; the order-0 table _basicwtmx,
// when present, uses the default PSEUDO_COUNT.  Does nothing when the
// matrix is already flagged as normalized.
void WtMx::Normalize()
{
  if (_is_normalized) return;

  for (int hist=0; hist<_powMorder; hist++) {
    float **table = _wtmx[hist];
    for (int col=0; col<_length; col++) {
      float total = 0;
      for (int row=0; row<4; row++) {
        total += table[row][col];
      }
      for (int row=0; row<4; row++) {
        table[row][col] = (table[row][col] + _pseudo_count)/(total+4*_pseudo_count);
      }
    }
  }

  if (_basicwtmx != NULL) {
    for (int col=0; col<_length; col++) {
      float total = 0;
      for (int row=0; row<4; row++) {
        total += _basicwtmx[row][col];
      }
      for (int row=0; row<4; row++) {
        // the default pseudo count is used here, not _pseudo_count
        _basicwtmx[row][col] = (_basicwtmx[row][col] + PSEUDO_COUNT)/(total+4*PSEUDO_COUNT);
      }
    }
  }

  _is_normalized = true;
}

// Writes the header line ">name length" for this matrix to fp.
// Only the header is emitted.  The code that used to follow an
// unconditional return (printing the order-0 matrix as four rows of
// "%.4f " values and a closing "<" line) could never execute and has been
// removed; the output is unchanged.
void WtMx::Print(FILE *fp)
{
  fprintf(fp,">%s %d\n",_name,_length);
}

// Copies the matrix name into str.  When the name is shorter than pad_to
// characters it is right-padded with spaces to exactly pad_to characters
// and re-terminated.  The caller must supply a buffer large enough for the
// name and for pad_to+1 bytes; no bounds are checked here.
void WtMx::Name(char *str, int pad_to)
{
  strcpy(str,_name);
  int pos = strlen(str);
  if (pos >= pad_to) return;

  while (pos < pad_to) {
    str[pos++] = ' ';
  }
  str[pad_to] = 0;
}

// Returns the stored frequency of symbol `index` at column `offset`,
// ignoring any history.
//   index > 3      : 1 for the special 'N' matrix, 0 otherwise.
//   special matrix : 0 for index 0..3.
//   otherwise      : the entry of _basicwtmx (Markov matrices) or of
//                    _wtmx[0] (order-0 matrices); values <= 0 are replaced
//                    by SMALL_FREQUENCY.
float WtMx::Frequency(int offset, int index)
{
  if (index>3) {
    if (_is_special && _special_char=='N') return 1;
    return 0;
  }
  if (_is_special) return 0;

  // Markov matrices answer history-free queries from the order-0 table.
  float **table = (_powMorder != 1) ? _basicwtmx : _wtmx[0];
  float freq = table[index][offset];
  if (freq <= 0) return SMALL_FREQUENCY;
  return freq;
}

// Returns the entry of _higherwtmx for the word `index`.
// `offset` is not used.  No check is made that _higherwtmx was allocated
// or that index is in range.
float WtMx::HigherOrderFrequency(int offset, int index)
{
  const float value = _higherwtmx[index];
  return value;
}

// Returns the frequency of symbol `index` at column `offset` given the
// numeric history word `history`.
//   order-0 matrix : history is ignored, same as Frequency(offset,index).
//   index > 3      : 1 for the special 'N' matrix, 0 otherwise.
//   special matrix : 0 for index 0..3.
//   history outside 0.._powMorder-1 : prints an error and exits.
//   otherwise      : _wtmx[history][index][offset], with values <= 0
//                    replaced by SMALL_FREQUENCY.
// The statements that used to follow exit(1) in the negative-history branch
// were unreachable and have been removed; behaviour is unchanged.
float WtMx::Frequency(int offset, int index, int history)
{
  if (_powMorder==1) return Frequency(offset,index); // dont worry about history word

  if (index>3) return (_is_special && _special_char=='N'?1:0);
  if (_is_special) return 0;
  if (history >= _powMorder) {
    printf("Error: Frequency requested for Markov matrix, not enough information to handle history\n");
    exit(1);
  }
  if (history < 0) {
    printf("Error: Negative history provided to Frequency\n");
    exit(1);
  }

  float retval = _wtmx[history][index][offset];
  if (retval <= 0) return SMALL_FREQUENCY;
  return retval;
}

// Returns the frequency of symbol `index` at column `offset` given the
// history characters seq[0..history_length-1].
//   order-0 matrix : history is ignored, same as Frequency(offset,index).
//   index > 3      : 1 for the special 'N' matrix, 0 otherwise.
//   special matrix : 0 for index 0..3.
//   history_length != _Morder : prints an error and exits.
// The history word is built base 4 (A=0, C=1, G=2, T=3, either case) with
// seq[0] as the most significant digit.  If any history character is not
// one of those, the order-0 table _basicwtmx is used instead.  Values <= 0
// are replaced by SMALL_FREQUENCY.
float WtMx::Frequency(int offset, int index, char *seq, int history_length)
{
  if (_powMorder==1) return Frequency(offset, index); // dont worry about history word

  if (index>3) return (_is_special && _special_char=='N'?1:0);
  if (_is_special) return 0;
  if (_Morder != history_length) {
    printf("Error: Markov order of weight matrix doesnt match the history provided\n");
    exit(1);
  }

  // Encode the history; word becomes -1 at the first unrecognised character.
  int word = 0;
  for (int k=0; k<history_length && word>=0; k++) {
    const char base = seq[k];
    int code;
    if (base=='a' || base=='A') code = 0;
    else if (base=='c' || base=='C') code = 1;
    else if (base=='g' || base=='G') code = 2;
    else if (base=='t' || base=='T') code = 3;
    else code = -1;

    word = (code < 0) ? -1 : word*4 + code;
  }

  float **table = (word < 0) ? _basicwtmx : _wtmx[word];
  float freq = table[index][offset];
  if (freq <= 0) return SMALL_FREQUENCY;
  return freq;
}

// Draws one of 'A','C','G','T' at random according to column `offset` of
// an order-0 matrix.
// Prints an error and exits when offset is not a valid column, when the
// matrix is special, or when it is a Markov matrix (no history available).
// Fixes relative to the previous version:
//   - offset == _length and negative offsets are now rejected instead of
//     reading outside the table;
//   - the cumulative table lives on the stack, so nothing is leaked per call.
char WtMx::GetRandomChar(int offset)
{
  if (offset < 0 || offset >= _length) {
    printf("Error: random char requested at non-existing position of weight matrix\n");
    exit(1);
  }
  if (_is_special) {
    printf("Error: random char requested for special matrix\n");
    exit(1);
  }
  if (_powMorder != 1) {
    printf("Error: random char requested for Markov matrix, no history provided\n");
    exit(1);
  }

  // Cumulative distribution over A, C, G, T for this column.
  float cum[4];
  cum[0] = _wtmx[0][0][offset];
  for (int i=1; i<4; i++) {
    cum[i] = cum[i-1] + _wtmx[0][i][offset];
  }

  float r = random()/float(RAND_MAX);
  if (r < cum[0]) return 'A';
  if (r < cum[1]) return 'C';
  if (r < cum[2]) return 'G';
  return 'T';
}

// Draws one of 'A','C','G','T' at random according to column `offset` of
// the table for the numeric history word `history`.
// An order-0 matrix ignores history and defers to GetRandomChar(offset).
// Prints an error and exits when offset is not a valid column, when the
// matrix is special, or when history is outside 0.._powMorder-1.
// Fixes relative to the previous version:
//   - offset == _length and negative offsets are now rejected instead of
//     reading outside the table;
//   - the cumulative table lives on the stack, so nothing is leaked per call.
char WtMx::GetRandomChar(int offset, int history)
{
  if (_powMorder==1) return GetRandomChar(offset);

  if (offset < 0 || offset >= _length) {
    printf("Error: random char requested at non-existing position of weight matrix\n");
    exit(1);
  }
  if (_is_special) {
    printf("Error: random char requested for special matrix\n");
    exit(1);
  }
  if (history < 0 || history >= _powMorder) {
    printf("Error: random char requested for Markov matrix, incorrect history provided\n");
    exit(1);
  }

  // Cumulative distribution over A, C, G, T for this column and history.
  float cum[4];
  cum[0] = _wtmx[history][0][offset];
  for (int i=1; i<4; i++) {
    cum[i] = cum[i-1] + _wtmx[history][i][offset];
  }

  float r = random()/float(RAND_MAX);
  if (r < cum[0]) return 'A';
  if (r < cum[1]) return 'C';
  if (r < cum[2]) return 'G';
  return 'T';
}

// Draws one of 'A','C','G','T' at random according to column `offset`,
// given the history characters seq[0..history_length-1].
// An order-0 matrix ignores history and defers to GetRandomChar(offset).
// Prints an error and exits when offset is not a valid column, when the
// matrix is special, or when history_length != _Morder.
// The history word is built base 4 (A=0, C=1, G=2, T=3, either case) with
// seq[0] as the most significant digit; if any history character is not one
// of those, the order-0 table _basicwtmx is sampled instead.
// Fixes relative to the previous version:
//   - offset == _length and negative offsets are now rejected instead of
//     reading outside the table;
//   - the cumulative table lives on the stack, so nothing is leaked per call.
char WtMx::GetRandomChar(int offset, char *seq, int history_length)
{
  if (_powMorder==1) return GetRandomChar(offset);

  if (offset < 0 || offset >= _length) {
    printf("Error: random char requested at non-existing position of weight matrix\n");
    exit(1);
  }
  if (_is_special) {
    printf("Error: random char requested for special matrix\n");
    exit(1);
  }
  if (_Morder != history_length) {
    printf("Error: Markov order of weight matrix doesnt match the history provided\n");
    exit(1);
  }

  int history = 0;
  for (int i=0; i<history_length; i++) {
    int ch;
    switch(*seq) {
    case 'a':
    case 'A': ch = 0; break;
    case 'c':
    case 'C': ch = 1; break;
    case 'g':
    case 'G': ch = 2; break;
    case 't':
    case 'T': ch = 3; break;
    default: ch = 4; break;
    }
    if (ch >= 4) {
      history = -1;   // unusable history: fall back to the order-0 table
      break;
    }
    history = history*4 + ch;
    seq++;
  }

  // Cumulative distribution over A, C, G, T from the selected table.
  float **table = (history < 0) ? _basicwtmx : _wtmx[history];
  float cum[4];
  cum[0] = table[0][offset];
  for (int i=1; i<4; i++) {
    cum[i] = cum[i-1] + table[i][offset];
  }

  float r = random()/float(RAND_MAX);
  if (r < cum[0]) return 'A';
  if (r < cum[1]) return 'C';
  if (r < cum[2]) return 'G';
  return 'T';
}

int WtMx::MarkovOrder()
{
  return _Morder;
}

int WtMx::Length()
{
  return _length;
}

int WtMx::GetUserData()
{
  return _udata;
}

void WtMx::SetUserData(int d)
{
  _udata = d;
}

void WtMx::SetForwardBias(float bias)
{
  _forward_bias = bias;
}

float WtMx::GetForwardBias()
{
  return _forward_bias;
}

// Overwrites column `offset` of an order-0 matrix with freq[0..3], each
// converted to float.  Prints an error and exits for special or Markov
// matrices.  The new values are stored as given (no normalization), and
// offset is not range-checked.
void WtMx::UpdateFrequency(int offset, DTYPE *freq)
{
  if (_is_special) {
    printf("Tried updating frequencies of a special matrix\n");
    exit(1);
  }
  if (_powMorder != 1) {
    printf("Tried updating frequencies of a Markov matrix\n");
    exit(1);
  }

  float **table = _wtmx[0];
  for (int row=0; row<4; row++) {
    table[row][offset] = (float)(freq[row]);
  }
}

// Creates an empty collection with no valid matrices.
WtMxCollection::WtMxCollection()
{
  this->_numValid = 0;
}

// Loads every weight matrix found in the file flname and adds it to the
// collection.  Two record formats are accepted, selected by the first
// character of the record's header line:
//   '>' : ">name length [... PSEUDO_COUNT value]", followed by `length`
//         lines of four numbers (one line per column), followed by a line
//         starting with '<'.
//   '#' : "# name ... len= length [... PSEUDO_COUNT value]", followed by
//         four rows, each a leading token and then `length` numbers.
// Any other line, or any malformed record, prints an error and exits.
// Fixes relative to the previous version:
//   - every sscanf/fscanf result that feeds `length` or the matrix is
//     checked, so a truncated or malformed file can no longer leave them
//     uninitialised;
//   - %s conversions carry a field width matching the 1024-byte buffers;
//   - a negative length is rejected in the '>' format as well.
WtMxCollection::WtMxCollection(char *flname)
{
  FILE *fp = fopen(flname,"r");
  if (fp == NULL) {
    printf("Error: couldnt open file %s\n",flname);
    exit(1);
  }

  _numValid = 0;
  char line[1024];
  char name[1024];
  char dummy[1024];
  int  length;
  float **w = new float *[4];   // scratch table, rows reallocated per record
  int i;

  while (fgets(line,1023,fp)) {
    if (line[0] != '>' && line[0] != '#') {
      printf("Error: Weight Matrix File format not recognized\n");
      exit(1);
    }

    if (line[0]=='>') {
      if (sscanf(line,">%1023s %d",name,&length) != 2 || length < 0) {
        printf("Error reading weight matrix file\n");
        exit(1);
      }

      // Optional per-matrix pseudo count; the default is kept when absent
      // or unparsable.
      float pseudo_count = PSEUDO_COUNT;
      char *psd = strstr(line,"PSEUDO_COUNT");
      if (psd) {
        sscanf(psd,"%1023s %f",dummy,&pseudo_count);
      }

      for (i=0; i<4; i++) {
        w[i] = new float[length];
      }
      for (i=0; i<length; i++) {
        if (!fgets(line,1023,fp) || sscanf(line,"%f %f %f %f",&(w[0][i]),&(w[1][i]),&(w[2][i]),&(w[3][i])) != 4) {
          printf("Error reading weight matrix file\n");
          exit(1);
        }
      }
      WtMx *wm = new WtMx(w,length,name,pseudo_count);   // copies w
      for (i=0; i<4; i++) {
        delete [] w[i];
      }

      Add(wm);

      if (!fgets(line,1023,fp) || line[0] != '<') {
        printf("Error reading weight matrix file\n");
        exit(1);
      }
    }
    else {   // line[0]=='#'
      if (sscanf(line,"%1023s %1023s",dummy,name) != 2) {
        printf("Error reading weight matrix file\n");
        exit(1);
      }
      char *ln = strstr(line,"len=");
      if (ln == NULL) {
        printf("Error reading weight matrix file\n");
        exit(1);
      }
      // "len=" is consumed as a token, the number must follow separately.
      if (sscanf(ln,"%1023s %d\n",dummy,&length) != 2 || length < 0) {
        printf("Error reading weight matrix file\n");
        exit(1);
      }

      float pseudo_count = PSEUDO_COUNT;
      char *psd = strstr(line,"PSEUDO_COUNT");
      if (psd) {
        sscanf(psd,"%1023s %f",dummy,&pseudo_count);
      }

      for (i=0; i<4; i++) {
        w[i] = new float[length];
      }
      for (int a=0; a<4; a++) {
        if (fscanf(fp,"%1023s",dummy) != 1) { // the first char
          printf("Error reading weight matrix file\n");
          exit(1);
        }
        for (i=0; i<length; i++) {
          if (fscanf(fp,"%f ",&(w[a][i])) != 1) {
            printf("Error reading weight matrix file\n");
            exit(1);
          }
        }
      }

      WtMx *wm = new WtMx(w,length,name,pseudo_count);   // copies w
      for (i=0; i<4; i++) {
        delete [] w[i];
      }

      Add(wm);
    }
  }

  fclose(fp);
  delete [] w;
}

// Deletes the matrices that are still valid members of the collection.
// The previous version deleted _vec[0.._numValid-1], which after any
// Remove() no longer corresponds to the valid entries: it could delete a
// matrix already handed back to the caller and skip one still in the
// collection.  Entries are now reached through _valid.
// NOTE(review): this assumes Remove() transfers ownership of the returned
// matrix to the caller - confirm against the callers.
WtMxCollection::~WtMxCollection()
{
  for (int i=0; i<_numValid; i++) {
    delete _vec[_valid[i]];   // delete on a null pointer does nothing
  }
  _vec.clear();
}

// Appends w to the collection and marks it valid.
// Returns the index of the new matrix among the valid ones (the value of
// _numValid before the insertion).
int WtMxCollection::Add(WtMx *w)
{
  const int slot = _numValid;
  _vec.push_back(w);
  _valid.push_back(_vec.size()-1);   // position of w inside _vec
  _numValid = slot + 1;
  return slot;
}

// Takes the index-th valid matrix out of the valid list and returns it.
// The pointer stays in _vec; only the _valid entry is dropped and the
// later entries move down by one.  index is not range-checked.
WtMx *WtMxCollection::Remove(int index)
{
  WtMx *removed = _vec[_valid[index]];

  // Close the gap left by the removed entry.
  int pos = index;
  while (pos < _numValid-1) {
    _valid[pos] = _valid[pos+1];
    pos++;
  }
  _valid.pop_back();
  _numValid--;

  return removed;
}

// Returns the index-th valid matrix.  index is not range-checked.
WtMx *WtMxCollection::WM(int index)
{
  WtMx *matrix = _vec[_valid[index]];
  return matrix;
}

int WtMxCollection::Size()
{
  return _numValid;
}

int WtMxCollection::TotalSize()
{
  return _vec.size();
}

// Returns the largest Length() among the valid matrices, or -1 when the
// collection has no valid matrices.
int WtMxCollection::MaxLength()
{
  int longest = -1;
  const int count = Size();
  for (int i=0; i<count; i++) {
    const int len = _vec[_valid[i]]->Length();
    if (len > longest) longest = len;
  }
  return longest;
}

void WtMxCollection::Print()
{
  for (int i=0; i<Size(); i++) {
    _vec[_valid[i]]->Print();
  }
} 



