/*****************************************************************
        Copyright by Rockefeller University,
can not be reproduced or distributed without written permission of
copyright holder.  Version of October 2003.

Written by Saurabh Sinha (contact person), Erik van Nimwegen, and 
Eric Siggia.

The program stubb (and its relatives) implement an algorithm for
finding likely cis-regulatory modules, described in the following
paper:
"A Probabilistic Method to Detect Regulatory Modules"
by Saurabh Sinha, Erik van Nimwegen and Eric Siggia. 
Eleventh International Conference on Intelligent Systems for
Molecular Biology, Brisbane, Australia, July 2003, pg 292-301.

The file sample/gap_wtmx that comes with this distribution includes 
a sample set of transcription factor weight matrices (PWM's) that 
were reported in :
"Computational detection of genomic cis-regulatory modules applied
to body patterning in the early Drosophila embryo"
by N. Rajewsky, M. Vergassola, U. Gaul and E. Siggia.
BMC Bioinformatics 3 (30) 2002.
******************************************************************/

#include "sequence.h"
#include <math.h>

// Definition of the static data member Sequence::_alignments, initially NULL.
// Being static, the pointer is shared by every Sequence object; note that both
// Sequence constructors in this file reset it to NULL and ~Sequence() deletes
// whatever it points to.
class Alignment *Sequence::_alignments = NULL;

// Builds a Sequence from the Fasta file named flname.
// The text returned by ReadFasta() becomes _seq, and every character is
// translated into a small integer code kept in _indexseq:
//   A/a -> 0, C/c -> 1, G/g -> 2, T/t -> 3, anything else -> 4.
// Prints an error and exits if the file cannot be opened.
Sequence::Sequence(char *flname)
{
  FILE *infile = fopen(flname,"r");
  if (!infile) {
    printf("Error reading file %s\n",flname);
    exit(1);
  }
  _seq = ReadFasta(infile);
  fclose(infile);

  _length = strlen(_seq);
  _indexseq = new char[_length];
  for (int pos = 0; pos < _length; ++pos) {
    char code;
    switch (_seq[pos]) {
    case 'A': case 'a': code = 0; break;
    case 'C': case 'c': code = 1; break;
    case 'G': case 'g': code = 2; break;
    case 'T': case 't': code = 3; break;
    default:            code = 4; break;
    }
    _indexseq[pos] = code;
  }

  // Defaults for the remaining fields.
  _udata = NULL;
  _alignments = NULL;
  sprintf(_name,"No Name");
  _id = 0; // may be set from outside if relevant
}

// Builds a Sequence from an in-memory character string.
//   seq    - source characters; at most `length` of them are used
//   length - number of characters the new sequence holds
//   name   - name copied into _name (the caller must ensure it fits)
// Exactly `length` characters are stored.  The copy is bounded by `length`,
// so a longer source string cannot overrun the length+1 buffer; if the
// source is shorter, the remainder is zero-filled (and therefore indexed as
// ambiguous, code 4) instead of being left uninitialised.
// Character codes: A/a -> 0, C/c -> 1, G/g -> 2, T/t -> 3, other -> 4.
Sequence::Sequence(char *seq, int length, char *name)
{
  _length = length;
  _seq = new char[_length+1];
  strncpy(_seq,seq,_length);   // bounded copy; pads with 0 when seq is short
  _seq[_length] = 0;           // strncpy does not terminate on truncation

  _indexseq = new char[_length];
  for (int i=0; i<_length; i++) {
    switch(_seq[i]) {
    case 'a':
    case 'A': _indexseq[i] = 0; break;
    case 'c':
    case 'C': _indexseq[i] = 1; break;
    case 'g':
    case 'G': _indexseq[i] = 2; break;
    case 't':
    case 'T': _indexseq[i] = 3; break;
    default:  _indexseq[i] = 4;
    }
  }

  _id = 0;
  strcpy(_name,name);
  _alignments = NULL;
  _udata = NULL;
}

// Releases the character and index buffers owned by this object, and also
// deletes the Alignment referenced by the static _alignments pointer,
// leaving that pointer NULL.
Sequence::~Sequence()
{
  delete [] _indexseq;
  delete [] _seq;

  // Deleting a null pointer is a no-op, so no explicit test is needed.
  delete _alignments;
  _alignments = NULL;
}

// Returns the raw character stored at position `index` (no bounds check).
char Sequence::CharAt(int index)
{
  const char residue = _seq[index];
  return residue;
}

// Returns the integer code stored in _indexseq for position `index`
// (no bounds check).
char Sequence::IndexOfCharAt(int index)
{
  const char code = _indexseq[index];
  return code;
}

// True when the code at position `index` is above 3, i.e. it is not one of
// the four codes 0..3 assigned to A, C, G, T by the constructors.
bool Sequence::AmbiguousCharAt(int index)
{
  const char code = IndexOfCharAt(index);
  return code >= 4;
}

int Sequence::Length()
{
  return _length;
}

// Reads Fasta text from fp and returns it as one newly allocated,
// NUL-terminated string (allocated with new[]).
//  - Lines starting with '#' are skipped everywhere.
//  - The first remaining line must start with '>'; otherwise (including an
//    empty or comment-only file) an error is printed and the program exits.
//  - Every later line containing '>' contributes a single "$" separator, so
//    all records end up concatenated in one string.
//  - All trailing '\n' / '\r' characters are removed from sequence lines, so
//    CRLF files do not leave a stray '\r' in the data.
//  - If the accumulated text would exceed MAX_SEQ_LENGTH characters an error
//    is printed and the program exits instead of overrunning the buffer.
char *Sequence::ReadFasta(FILE *fp)
{
  char line[MAX_LINE_LENGTH];

  // Find the first non-comment line; it has to be a header.
  bool have_header = false;
  while (fgets(line,MAX_LINE_LENGTH-1,fp)) {
    if (line[0] == '#') continue;
    have_header = (line[0] == '>');
    break;
  }
  if (!have_header) {
    printf("Error: In Fasta file, each sequence must be preceded by a header line beginning with a '>'\n");
    exit(1);
  }

  char *sequence = new char[MAX_SEQ_LENGTH+1];
  sequence[0] = 0;
  int seq_pos = 0;
  while (fgets(line,MAX_LINE_LENGTH-1,fp)) {
    if (line[0]=='#') continue;

    const char *piece = line;
    if (strstr(line,">")) {
      piece = "$";   // record boundary: we want all sequences as one
    }
    else {
      // chomp every trailing line terminator
      int len = strlen(line);
      while (len > 0 && (line[len-1]=='\n' || line[len-1]=='\r')) line[--len] = 0;
    }

    // concatenate, refusing to write past the end of the buffer
    int piece_len = strlen(piece);
    if (seq_pos + piece_len > MAX_SEQ_LENGTH) {
      printf("Error: Fasta input is longer than the maximum supported sequence length\n");
      exit(1);
    }
    strcpy(&(sequence[seq_pos]),piece);
    seq_pos += piece_len;
  }

  // Hand back a right-sized copy and release the large scratch buffer.
  char *retval = new char[strlen(sequence)+1];
  strcpy(retval,sequence);
  delete [] sequence;

  return retval;
}

// Writes the characters at positions start..stop (both inclusive) to
// stdout, followed by a newline.  No bounds checking is performed.
void Sequence::Print(int start, int stop)
{
  int pos = start;
  while (pos <= stop) {
    printf("%c",_seq[pos]);
    ++pos;
  }
  printf("\n");
}

// Copies the stored sequence name into the caller-supplied buffer `name`.
// The buffer must be large enough for the name and its terminator.
void Sequence::Name(char *name)
{
  const char *stored = _name;
  strcpy(name,stored);
}

void Sequence::SetSpeciesIndex(int i)
{
  _id = i;
}

int Sequence::GetSpeciesIndex()
{
  return _id;
}

void Sequence::SetUserData(void *udata)
{
  _udata = udata;
}

void *Sequence::GetUserData()
{
  return _udata;
}

// Creates a window over positions start..stop (inclusive) of `seq`.
// Per position it caches the character code, whether that code is
// ambiguous, and an alignment-boundary marker (initially 0).  With
// _MULTIPLE_SEQUENCES it also reserves one cell of
// 1+2*MAX_ALIGNED_SEQUENCES chars per position, either as a single slab
// (_MEMORY_HACK) or as one allocation per position.
// Prints an error and exits if the window would be shorter than 1.
Window::Window(Sequence *seq, int start, int stop)
  //PRECONDITION: start and stop must not lie on opposite sides of a "$"
{
  _seq = seq;
  _start = start;
  _stop = stop;
  _length = _stop - _start + 1;
  if (_length <= 0) {
    printf("Error: window smaller than 1 long requested\n");
    exit(1);
  }

  _indexofchar = new char[_length];
  _ambiguouschar = new bool[_length];
  _alignmentbdry = new char[_length];

  for (int k = 0; k < _length; ++k) {
    const char code = _seq->IndexOfCharAt(_start + k);
    _alignmentbdry[k] = 0;
    _ambiguouschar[k] = AmbiguousChar(code);
    _indexofchar[k] = code;
  }

#ifdef _MULTIPLE_SEQUENCES
  _arrayofchar = new char *[_length];
#ifdef _MEMORY_HACK
  // One contiguous slab; each position points at its own slice.
  char *slab = new char[_length*(1+2*MAX_ALIGNED_SEQUENCES)];
  for (int k = 0; k < _length; ++k) {
    _arrayofchar[k] = slab + k*(1+2*MAX_ALIGNED_SEQUENCES);
  }
#else
  // One separately allocated cell per position.
  for (int k = 0; k < _length; ++k) {
    _arrayofchar[k] = new char[(1+2*MAX_ALIGNED_SEQUENCES)];
  }
#endif
#endif

}

// Frees every buffer allocated by the constructor.  With _MEMORY_HACK the
// per-position cells live in one slab whose address is _arrayofchar[0];
// otherwise each cell is freed individually.
Window::~Window()
{
#ifdef _MULTIPLE_SEQUENCES
#ifdef _MEMORY_HACK
  delete [] _arrayofchar[0];
#else
  int k = 0;
  while (k < _length) {
    delete [] _arrayofchar[k];
    ++k;
  }
#endif
  delete [] _arrayofchar;
#endif
  delete [] _alignmentbdry;
  delete [] _ambiguouschar;
  delete [] _indexofchar;
}

// True when the character code `index` is above 3 (not one of 0..3).
bool Window::AmbiguousChar(int index)
{
  return index >= 4;
}

// Returns the Sequence this window was created over.
Sequence *Window::Seq()
{
  return this->_seq;
}

int Window::Start()
{
  return _start;
}

int Window::Stop()
{
  return _stop;
}

int Window::Length()
{
  return _length;
}

// Returns the character code at window position `index`.
// A stored code of -1 marks a position holding several aligned sequences:
// with _MULTIPLE_SEQUENCES, `arrayofchar` is then pointed at that position's
// cell and -1 is returned; without it, this is reported as an error and the
// program exits.  `arrayofchar` is left untouched for ordinary positions.
char Window::IndexOfCharAt(int index, char *&arrayofchar) 
{
  const char code = _indexofchar[index];
  if (code == -1) {
#ifdef _MULTIPLE_SEQUENCES
    arrayofchar = _arrayofchar[index];
#else 
    printf("Error: Negative index of char for single sequence\n");
    exit(1);
#endif
  }
  return code;
}

// Returns the reference sequence's character code at window position
// `index`.  For an ordinary position that is the stored code; for a position
// marked -1 (several aligned sequences) it is the third entry of the
// position's cell.  Without _MULTIPLE_SEQUENCES a -1 marker is an error and
// the program exits.
char Window::IndexOfCharAtInReferenceSequence(int index) 
{
  const char code = _indexofchar[index];
  if (code != -1) return code;

#ifdef _MULTIPLE_SEQUENCES
  return _arrayofchar[index][2];
#else 
  printf("Error: Negative index of char for single sequence\n");
  exit(1);
  return code;   // not reached; keeps every path returning a value
#endif
}

// Returns the cached ambiguity flag for window position `index`.
// The flag is set by the constructor from the position's character code and
// is raised by AlignWindow() whenever an aligned window contributes an
// ambiguous character at that position.
// (The code that used to follow the return statement recomputed the answer
// from the per-position cell but could never execute; it has been removed.)
bool Window::AmbiguousCharAt(int index)
{
  return _ambiguouschar[index];
}

// True when the reference sequence's character code at window position
// `index` is above 3.  For a position marked -1 the code is taken from the
// third entry of the position's cell; without _MULTIPLE_SEQUENCES such a
// marker is an error and the program exits.
bool Window::AmbiguousCharAtInReferenceSequence(int index)
{
  const char code = _indexofchar[index];
  if (code != -1) return (code >= 4);

#ifdef _MULTIPLE_SEQUENCES
  // third position in cell is the index for reference sequence char
  return (_arrayofchar[index][2] >= 4);
#else
  printf("Error: Negative index of char for single sequence\n");
  exit(1);
  return (code >= 4);   // not reached; keeps every path returning a value
#endif
}

// True when the boundary marker at `index` is 1, the value AlignWindow()
// writes at the first position of an aligned window.
bool Window::AlignmentBeginsAt(int index)
{
  const char marker = _alignmentbdry[index];
  return marker == 1;
}

// True when the boundary marker at `index` is 2, the value AlignWindow()
// writes at the position just past the end of an aligned window.
bool Window::AlignmentEndsAt(int index)
{
  const char marker = _alignmentbdry[index];
  return marker == 2;
}

// Returns a newly allocated Window over the same sequence, widened by
// `size` positions on each side and clamped to the sequence's bounds.
// The caller owns the returned object.
Window *Window::Context(int size)
{
  const int last = _seq->Length()-1;

  int lo = _start - size;
  if (lo < 0) lo = 0;

  int hi = _stop + size;
  if (hi > last) hi = last;

  return new Window(_seq,lo,hi);
}

// Records that window `w` is aligned to this window, with w's first
// position placed at this window's position `offset` (which may be negative
// or run past the end; only the overlapping part is recorded).
// Effects on each overlapped position i:
//  - the ambiguity flag is raised if w's character there is ambiguous;
//  - the position's cell receives a (species index, char code) pair for w.
//    The cell layout is [count, idx1, code1, idx2, code2, ...]; the first
//    alignment converts the position to that form, storing this window's own
//    pair first, and marks _indexofchar[i] with -1.
// Boundary markers 1 (begin) and 2 (end) are written at `offset` and at
// offset+w->Length() when those fall inside this window.
// A cell holds at most MAX_ALIGNED_SEQUENCES pairs; attempting to add more
// is reported as an error and exits instead of writing past the cell.
// Without _MULTIPLE_SEQUENCES the call is an error and the program exits.
void Window::AlignWindow(Window *w, int offset)
{
#ifdef _MULTIPLE_SEQUENCES
  int wlen = w->Length();
  int endoffset = offset+wlen;
  if (offset >= _length || endoffset < 0) {
    printf("Error: Cannot align window as requested\n");
    exit(1);
  }

  char wseqindex = (char)w->Seq()->GetSpeciesIndex();

  if (offset >= 0) _alignmentbdry[offset] = 1;
  if (endoffset < _length) _alignmentbdry[endoffset] = 2;

  for (int i=offset; i<endoffset; i++) {
    if (i < 0) continue;
    if (i >= _length) break;
    char wchindex = w->_indexofchar[i-offset];
    if (AmbiguousChar(wchindex)) _ambiguouschar[i] = true;
    if (_indexofchar[i] == -1) {
      char num_seq = _arrayofchar[i][0];
      // The cell has room for MAX_ALIGNED_SEQUENCES pairs only.
      if (num_seq >= MAX_ALIGNED_SEQUENCES) {
        printf("Error: Too many sequences aligned at one position\n");
        exit(1);
      }
      _arrayofchar[i][1+2*num_seq] = wseqindex;
      _arrayofchar[i][2+2*num_seq] = wchindex;
      _arrayofchar[i][0]++;
    }
    else {
      char *ar = _arrayofchar[i];  // Create the flat list now
      *ar++ = 2;                  // number of sequences
      *ar++ = (char)_seq->GetSpeciesIndex();   // sequence 1 index
      *ar++ = _indexofchar[i];    // sequence 1 char index
      *ar++ = wseqindex;          // sequence 2 index
      *ar++ = wchindex;           // sequence 2 char index
      _indexofchar[i] = -1;
    }
  }
#else
  printf("Error:  Align Window called on single sequence\n");
  exit(1);
#endif
}

// Counts the window positions whose ambiguity flag is not set.
int Window::NumSpecificCharacters()
{
  int specific = 0;
  int pos = 0;
  while (pos < _length) {
    specific += AmbiguousCharAt(pos) ? 0 : 1;
    ++pos;
  }
  return specific;
}
 
// Tallies how often each of the four base codes (0..3) occurs in the
// reference sequence within this window and writes the tallies to
// freq[0..3].  Positions whose code is 4 or more are skipped.
// Note: the values written are raw counts; no normalisation is applied here.
void Window::ComputeBaseFrequencies(float *freq)
{
  int counts[4] = {0, 0, 0, 0};

  for (int i=0; i<_length; i++) {
    char ch = IndexOfCharAtInReferenceSequence(i);
    if (ch >= 4) continue;
    counts[int(ch)]++;
  }

  for (int i=0; i<4; i++) freq[i] = float(counts[i]);
}

// Counts bases of the reference sequence conditioned on the preceding
// `history_length` unambiguous bases.
//   freq           - table with 4^history_length rows of 4 floats; it is
//                    zeroed here and then freq[h][b] is incremented each
//                    time base b is seen with history word h
//   history_length - number of preceding bases forming the history word
// A history word is the base-4 number formed by the last history_length
// codes.  Ambiguous positions (code >= 4) are skipped.  The values written
// are raw counts.  If the window ends before an initial history word can be
// read, an error is printed and the program exits.
void Window::ComputeBaseFrequenciesWithHistory(float **freq, int history_length)
{
  int i;

  // 4^history_length, computed with integers so that no floating-point
  // rounding can make the table size come out one short.
  int powmorder = 1;
  for (i=0; i<history_length; i++) powmorder *= 4;

  // initialize
  for (i=0; i<powmorder; i++) {
    for (int j=0; j<4; j++) {
      freq[i][j] = 0;
    }
  }

  // construct the initial history word
  // NOTE(review): pos is not advanced after a history character is consumed
  // (see also the disabled "pos++" below), so for history_length > 1 the same
  // base is read repeatedly and is then counted again by the scan.  Left
  // unchanged because existing results depend on it -- confirm the intent.
  int mask = powmorder/4;
  int history = 0;
  char ch = 0; int pos = 0;
  for (i=0; i<history_length; i++) {
    // find the next non-ambiguous char
    while (pos < _length-1 && (ch = IndexOfCharAtInReferenceSequence(pos)) >= 4) pos++;
    // if not found, cannot do anything
    if (pos >= _length-1) {
      printf("Error: end of window encountered before a history word could be read\n");
      exit(1);
    }
    history = history*4 + ch;
  }

  // initial history word constructed
  // now scan the sequence and count bases
  // pos++;
  while (pos < _length) {
    // find the next non-ambiguous char
    while (pos < _length && (ch = IndexOfCharAtInReferenceSequence(pos)) >= 4) pos++;
    // if reached end, we're done counting
    if (pos >= _length) break;
    // else update the appropriate count
    freq[history][int(ch)]++;
    // and update the history word also: drop the oldest base, append ch
    if (history_length > 0) {
      history = history%mask;
      history = history*4 + ch;
    }
    pos++;
  }
}

// Prints the window to stdout.  Verbose mode prints a "Window: a to b"
// line followed by the covered characters; otherwise only the two
// tab-terminated coordinates are printed.  Nothing is printed when the
// window has no sequence or a negative coordinate.
void Window::Print(bool verbose)
{
  const bool printable = (_seq != NULL && _start >= 0 && _stop >= 0);
  if (!printable) return;

  if (!verbose) {
    printf("%d\t%d\t",_start,_stop);
    return;
  }

  printf("Window: %d to %d\n",_start,_stop);
  _seq->Print(_start,_stop);
}

// Describes one aligned block: positions l1..r1 of `seq` paired with
// positions l2..r2 of `seq2`.  The node starts out unlinked.
AlignmentNode::AlignmentNode(Sequence *seq, int l1, int r1, Sequence *seq2, int l2, int r2)
{
  // first sequence and its interval
  _thisSeq = seq;
  _l1 = l1;
  _r1 = r1;

  // second sequence and its interval
  _otherSeq = seq2;
  _l2 = l2;
  _r2 = r2;

  // list links
  _prev = NULL;
  _next = NULL;
}

// Moves both left ends `left` positions to the left.  The node is left
// unchanged if either new end would be negative, or if it would touch or
// overlap the right end of the previous node in the list.
void AlignmentNode::ExtendToLeft(int left)
{
  const int cand1 = _l1-left;
  const int cand2 = _l2-left;

  const bool in_range = (cand1 >= 0 && cand2 >= 0);
  const bool clear_of_prev =
    (_prev == NULL) || (cand1 > _prev->_r1 && cand2 > _prev->_r2);

  if (in_range && clear_of_prev) {
    _l1 = cand1;
    _l2 = cand2;
  }
}

// Moves both right ends `right` positions to the right.  The node is left
// unchanged if either new end would fall outside its sequence, or if it
// would touch or overlap the left end of the next node in the list.
void AlignmentNode::ExtendToRight(int right)
{
  const int cand1 = _r1+right;
  const int cand2 = _r2+right;

  if (cand1 >= _thisSeq->Length()) return;
  if (cand2 >= _otherSeq->Length()) return;
  if (_next != NULL && (cand1 >= _next->_l1 || cand2 >= _next->_l2)) return;

  _r1 = cand1;
  _r2 = cand2;
}

// Creates an empty alignment: zero sequences and no node lists.
Alignment::Alignment()
{
  _num_seq = 0;
  _alist = NULL;
}

// Creates an alignment with one (initially empty) node list for each of
// `num_seq` sequences.
Alignment::Alignment(int num_seq)
{
  _alist = new vector<struct AlignmentNode *>[num_seq];
  _num_seq = num_seq;
}

// Copies `other`: allocates the same number of node lists and adds a fresh
// copy of every node through AddAlignmentNode().  Only the sequences and
// the two intervals are copied; the new nodes start out unlinked.
Alignment::Alignment(Alignment *other)
{
  _num_seq = other->_num_seq;
  _alist = new vector<struct AlignmentNode *>[_num_seq];
  if (other->_alist == NULL) return;

  for (int s = 0; s < _num_seq; ++s) {
    const vector<struct AlignmentNode *> &src = other->_alist[s];
    for (vector<struct AlignmentNode *>::const_iterator it = src.begin(); it != src.end(); ++it) {
      const AlignmentNode *from = *it;
      AddAlignmentNode(new AlignmentNode(from->_thisSeq, from->_l1, from->_r1,
                                         from->_otherSeq, from->_l2, from->_r2));
    }
  }
}

// Deletes every node held in the lists, then the lists themselves.
Alignment::~Alignment()
{
  if (_alist == NULL) return;

  for (int s = 0; s < _num_seq; ++s) {
    vector<struct AlignmentNode *> &nodes = _alist[s];
    for (vector<struct AlignmentNode *>::iterator it = nodes.begin(); it != nodes.end(); ++it) {
      delete *it;
    }
  }
  delete [] _alist;
  _alist = NULL;
}

// Inserts `nd` into the list belonging to its first sequence's species
// index.  The node goes in front of the first existing node whose _l1 is
// strictly greater than nd's, or at the end if there is none.
void Alignment::AddAlignmentNode(AlignmentNode *nd)
{
  vector<struct AlignmentNode *> &nodes = _alist[nd->_thisSeq->GetSpeciesIndex()];

  vector<struct AlignmentNode *>::iterator spot = nodes.begin();
  while (spot != nodes.end() && (*spot)->_l1 <= nd->_l1) ++spot;

  // Inserting at end() appends.
  nodes.insert(spot,nd);
}

// Returns the nodes of `seq` whose first interval overlaps l..r, chained
// together (in list order) through their _prev/_next fields; NULL if no
// node overlaps.  The links are written into the stored nodes themselves,
// so each call overwrites the links of the nodes it selects.
struct AlignmentNode *Alignment::GetAlignmentNodeList(Sequence *seq, int l, int r)
{
  vector<struct AlignmentNode *> &nodes = _alist[seq->GetSpeciesIndex()];
  const int count = nodes.size();

  struct AlignmentNode *head = NULL;
  struct AlignmentNode *tail = NULL;

  for (int k = 0; k < count; ++k) {
    struct AlignmentNode *cur = nodes[k];

    // skip nodes lying entirely to one side of [l,r]
    const bool overlaps = (r >= cur->_l1 && cur->_r1 >= l);
    if (!overlaps) continue;

    // append cur to the chain
    cur->_prev = tail;
    cur->_next = NULL;
    if (tail == NULL) head = cur;
    else tail->_next = cur;
    tail = cur;
  }
  return head;
}

// Writes one line per node, for every sequence's list, giving the species
// index and interval of both sides of the aligned block.
void Alignment::Print(FILE *fp)
{
  for (int s = 0; s < _num_seq; ++s) {
    const int count = _alist[s].size();
    for (int k = 0; k < count; ++k) {
      struct AlignmentNode *cur = _alist[s][k];
      const int first_id = cur->_thisSeq->GetSpeciesIndex();
      const int second_id = cur->_otherSeq->GetSpeciesIndex();
      fprintf(fp,"Seq %d (%d to %d) with Seq %d (%d to %d)\n",first_id, cur->_l1, cur->_r1, second_id, cur->_l2, cur->_r2);
    }
  }
}

void Alignment::PrintAnchs(FILE *fp)
{
  for (int i=0; i<1; i++) {
    for (int j=0; j<_alist[i].size(); j++) {
      struct AlignmentNode *nd = _alist[i][j];
      fprintf(fp,"(%d %d)=(%d %d) 100.00\n",nd->_l1+1, nd->_r1+1, nd->_l2+1, nd->_r2+1);
    }
  }
}

// Compares sequences `f` and `c` character by character over every aligned
// block stored for f whose other side is c.
//   mismatch - out: number of compared positions whose characters differ
//   total    - out: number of compared positions
// Returns mismatch/total, or 0 when no position was compared (previously
// this case divided zero by zero).  Also prints "mismatch total" to stdout.
// Each block is walked in lock step and stops at the shorter of the two
// intervals.
float Alignment::MutationRateInAlignments(Sequence *f, Sequence *c, int &mismatch, int &total)
{
  const int fidx = f->GetSpeciesIndex();
  const int cidx = c->GetSpeciesIndex();

  // Bind by reference: the list is only read, so there is no need to copy it.
  const vector<struct AlignmentNode *> &alist = _alist[fidx];

  mismatch=0; total=0;
  for (size_t i=0; i<alist.size(); i++) {
    struct AlignmentNode *nd = alist[i];
    if (nd->_thisSeq->GetSpeciesIndex() != fidx) continue;
    if (nd->_otherSeq->GetSpeciesIndex() != cidx) continue;
    int l1,l2;
    for (l1=nd->_l1, l2 = nd->_l2; l1 <= nd->_r1 && l2 <= nd->_r2; l1++,l2++) {
      if (f->CharAt(l1) != c->CharAt(l2)) mismatch++;
      total++;
    }
  }
  printf("%d %d\n",mismatch,total);

  if (total == 0) return 0;
  return float(mismatch)/float(total);
}

