#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <fstream>
#include <iostream>

#include "longRand.h"

using namespace std;

// A splice junction and the candidate motifs it may be assigned.
struct Junction {
  short nj;        // how many candidate motifs this junction has
  short *ms;       // the nj candidate motif codes (as returned by IdentifyMotif)
  short current;   // motif code this junction is assigned at the moment
  double *probs;   // nj-entry scratch array for resampling weights
};

// Nucleotide alphabet; N is used for any character that is not A, C, G or T.
enum BASES {A, C, G, T, N};
const int NB = N + 1;                        // alphabet size (5)
char bases[NB] = {'A', 'C', 'G', 'T', 'N'};  // BASES value -> character

const int LL = 50;  // size of the output line buffer used in main()
int ML;             // motif length, taken from the command line
int NM;             // number of distinct motif codes, NB^ML

char BUFF[30];      // scratch buffer for one token read from the data file

short IdentifyMotif(char *buff);
char *ReturnMotif(short val);

// take a dataset of potential splice sites and
// infer the maximum likelihood subset
int main(int argc, char *argv[]) {
  if(argc != 5) {
    cerr << "Usage: optimize data_file nr_steps motif_length result_file" << endl;
    exit(0);
  }
  // number of steps to run the optimization
  double nsteps = atof(argv[2]);
  ML = atoi(argv[3]);
  NM = long(pow(double(NB), ML));

  // read in the data

  // data file: 
  // number of splice junctions
  // one line per splice junction
  //    number of possible junctions
  //    list of the possible junctions
  ifstream infile(argv[1]);
  assert(infile);
  int njunc;
  infile >> njunc; // number os splice junctions
  Junction *juncs = new Junction[njunc];
  assert(juncs);
  for(int i = 0; i < njunc; i++) {
    infile >> juncs[i].nj;
    juncs[i].ms = new short[juncs[i].nj];
    assert(juncs[i].ms);
    for(int j = 0; j < juncs[i].nj; j++) {
      infile >> BUFF;
      juncs[i].ms[j] = IdentifyMotif(BUFF);
    }
    juncs[i].probs = new double[juncs[i].nj];
    assert(juncs[i].probs);
  }

  // we read all the motifs

  // get a random number stream
  LongRand rS;

  // initialize
  unsigned long *count = new unsigned long[NM];
  assert(count);
  for(int i = 0; i < NM; i++) {
    count[i] = 0;
  }

  // assign a motif to each of the junctions
  for(int i = 0; i < njunc; i++) {
    unsigned long tmp = rS.Uniform(double(juncs[i].nj));
    juncs[i].current = juncs[i].ms[tmp];
    count[juncs[i].current]++;
  }

  // now do the random walk
  unsigned long which;
  double r;
  double beta = 1.0;
  double beta_final = 4.0;
  int n_transient = (int) (nsteps/10);
  int n_deep_quench = (int) (nsteps/100);
  double beta_step = (beta_final-1.0)/((double) (nsteps-n_transient-n_deep_quench));
  //double pstep = 1.0;
  for(double steps = 0; steps < nsteps; steps += 1.0) {

    /**increase inverse temperature after initial phase ****/
    if(steps > n_transient){
      beta += beta_step;
    }
    /**if beta has run over beta_final we are in deep quench regime***/
    if(beta > beta_final){
      beta = 25.0;
    }
    
    // pick a random junction and reasssign it
    /*    if(1000000 * int(steps/1000000) == steps) {
      cerr <<"Step " << steps << endl;
      }*/
    which = rS.Uniform(njunc);
    while(juncs[which].nj < 2) {
      which = rS.Uniform(njunc);
    }
    // subtract its count from current counts
    count[juncs[which].current]--;
    // try to reassign the junction
    double maxcount = 0;
    for(int j = 0; j < juncs[which].nj; j++) {
      juncs[which].probs[j] = count[juncs[which].ms[j]];
      if(juncs[which].probs[j] > maxcount){
	maxcount = juncs[which].probs[j];
      }
    }
    if(maxcount > 0) {
      /***normalize relative to junction with maximal count and raise to power beta****/
      for(int j = 0; j< juncs[which].nj;j++){
	juncs[which].probs[j] = pow(juncs[which].probs[j]/maxcount,beta);
      }
      /**now get the cumulative distribution***/
      for(int j = 1; j < juncs[which].nj; j++) {
	juncs[which].probs[j] = juncs[which].probs[j] + juncs[which].probs[j-1];
      }
      
      // throw a random number and determine the new boundary
      r = (juncs[which].probs[juncs[which].nj-1])*rS.Next();
      short k = juncs[which].nj-1;
      while(r < juncs[which].probs[k] && k >= 0) {
	k--;
      }
      if(k < 0 || r > juncs[which].probs[k]) {
	k++;
      }
      juncs[which].current = juncs[which].ms[k];
      count[juncs[which].current]++;
    }
    else {
      juncs[which].current = rS.Uniform(juncs[which].nj);
      count[juncs[which].current]++;
    }
  }
    
  // report the junctions
  ofstream of(argv[4]);
  assert(of);
  for (int i = 0; i < njunc; i++) {
    char *v = ReturnMotif(juncs[i].current);
    cout << v << endl;
    delete[] v;
  }

  // print out the junction frequencies
  char *outline = new char[LL];
  assert(outline);
  for(short i = 0; i < NM; i++) {
    if(count[i] > 0) {
      char *v = ReturnMotif(i);
      sprintf(outline, "%s\t%1.10lf\0", v, double(count[i])/njunc);
      of << outline << endl;
      delete[] v;
    }
  }
  of.close();
}

// Encode the first ML characters of buff as a base-NB number, most
// significant character first, using A=0, C=1, G=2, T=3 and N=4 for any other
// character. If buff is shorter than ML, the remaining positions are encoded
// as N and nothing past the string terminator is read. (Previously, stale
// bytes left over from an earlier token were encoded instead.)
// NOTE: the result is a short, so the caller must keep NB^ML - 1 within the
// range of a short.
short IdentifyMotif(char *buff) {
  short val = 0;
  bool ended = false; // set once the terminating NUL of buff has been seen
  for(int i = 0; i < ML; i++) {
    short c = N;
    if(!ended) {
      switch(buff[i]) {
      case 'A':
        c = A;
        break;
      case 'C':
        c = C;
        break;
      case 'G':
        c = G;
        break;
      case 'T':
        c = T;
        break;
      case '\0':
        // end of the token: pad with N and stop reading buff
        ended = true;
        break;
      default:
        c = N;
        break;
      }
    }
    val = val * NB + c;
  }
  return val;
}

// Decode a motif code (as produced by IdentifyMotif) back into an ML-character
// string over bases[]. The string is allocated with new[]; the caller must
// release it with delete[].
char *ReturnMotif (short num) {
  char *motif = new char[ML + 1];
  assert(motif);
  motif[ML] = '\0';
  // Peel off base-NB digits least significant first, filling from the right.
  for(int pos = ML - 1; pos >= 0; --pos) {
    motif[pos] = bases[num % NB];
    num = num / NB;
  }
  return motif;
}
