/*****************************************************************
        Copyright by Rockefeller University,
can not be reproduced or distributed without written permission of
copyright holder.  Version of October 2003.

Written by Saurabh Sinha (contact person), Erik van Nimwegen, and 
Eric Siggia.

The program stubb (and its relatives) implement an algorithm for
finding likely cis-regulatory modules, described in the following
paper:
"A Probabilistic Method to Detect Regulatory Modules"
by Saurabh Sinha, Erik van Nimwegen and Eric Siggia. 
Eleventh International Conference on Intelligent Systems for
Molecular Biology, Brisbane, Australia, July 2003, pg 292-301.

The file sample/gap_wtmx that comes with this distribution includes 
a sample set of transcription factor weight matrices (PWM's) that 
were reported in :
"Computational detection of genomic cis-regulatory modules applied
to body patterning in the early Drosophila embryo"
by N. Rajewsky, M. Vergassola, U. Gaul and E. Siggia.
BMC Bioinformatics 3 (30) 2002.
******************************************************************/

#include <stdio.h>
#include <assert.h>
#include <math.h>

#include "util.h"
#include "sequence.h"
#include "parameters.h"
#include "fastafile.h"

// Program-wide state shared with other translation units.
// globalid: integer identifier; main() resets it to 0 at start-up.
int globalid = 0;
// Background weight matrix trained from the optional "-b <background file>"
// argument; remains NULL when no background file is supplied.
WtMx *global_background = NULL;

// Optional pruning/retraining step; defined at the end of this file.
void PruneMotifsAndTrain(int totalLen, int windowsize,
                         vector<Window *> *wl, WtMxCollection *wmc,
                         Parameters_H0 *param0, WtMxCollection *wmc_prune,
                         Parameters_H0 *&param0_prune);

main(int argc, char **argv)
{
  if (argc < 6) {
    printf("usage: %s <sequencefile> <wtmxfile> <windowsize> <shiftsize> <fitprobs_file> [-od <output_dir>] [-b <background file>] [-ft <energy thresold for printing>] [-ot <motif occurrence threshold for printing]\n",argv[0]);
    exit(1);
  }
  globalid = 0;
  global_background = NULL;

  // Read in the sequence(s)
  char *fastafile = argv[1];
  FastaFile sequences;
  sequences.ReadFasta(fastafile);

  // Read in the weight matrices
  char *wmcfile = argv[2];
  WtMxCollection wmc(wmcfile);
  int numWM = wmc.Size();

  // Read in the other two compulsory parameters
  int windowsize = atoi(argv[3]);
  int shiftsize = atoi(argv[4]);
  char *fitprobsfile = argv[5];
  FILE *fitprobsfp = fopen(fitprobsfile,"r");
  DTYPE *fitprobs = new DTYPE[numWM+1];
  for (int wc=0; wc <= numWM; wc++) {
    if (fscanf(fitprobsfp,"%f ",&(fitprobs[wc]))<1) {
      printf("Not enough probabilities in %s\n",fitprobsfile);
      exit(1);
    }
  }
  // Read in the optional arguments
  int argbase = 6;
  struct Options *opt = ReadOptionalArguments(argbase, argc, argv);
  
  if (opt->bkg_file != NULL) {
    Sequence *bkg_seq = new Sequence(opt->bkg_file);
    Window *bkg_window = new Window(bkg_seq,0,bkg_seq->Length()-1);
    global_background = Parameters::TrainWtMx(bkg_window);
    delete bkg_window;
    delete bkg_seq;
  }
  
  // Prepare for various printing
  FILE *prof = OpenProfile(fastafile, opt->output_dir);
  FILE *dict = OpenDictionary(fastafile, opt->output_dir);
  FILE *ener = OpenOutput (fastafile, opt->output_dir);
  PrintParameters(fastafile,wmcfile,windowsize, shiftsize, opt);
  
  for (int seqnum=0; seqnum < sequences.Size(); seqnum++) {
    Sequence *seq = sequences[seqnum];
    int seq_len = seq->Length();
    if (seq_len < 200) continue;
    
    // Declare that the matrices in wmc are not going to be modified, and cache the probabilities
#ifdef _OPTIMIZE_CACHESEQUENCEPROBABILITIES
    int cache_expires_at = -1;
#endif
    
    // now start iterating through the windows
    WindowIterator wi(seq);
    bool did_begin = false;
#ifdef _EXIST_SMALL_SEQUENCES
    for (did_begin = wi.Begin(min(windowsize,seq_len),shiftsize); did_begin && !wi.End(); wi.Next()) {
#else 
    for (did_begin = wi.Begin(windowsize,shiftsize); did_begin && !wi.End(); wi.Next()) {
#endif
      vector<Window *> *wl = new vector<Window *>;
      wi.CurrentWindowList(wl);
      
      // verify that the windows are good
      if (wl->size() < 1) {
	delete wl;
	continue;
      }
      
      int totalLen = 0;
      for (int windex=0; windex<wl->size(); windex++) {
	Window *win = (*wl)[windex];
	totalLen += win->Length();
      }
      
      if (!IsValidWindowList(wl)) {
	int startpos = (*wl)[0]->Start();
	for (int windex=0; windex<wl->size(); windex++) {
	  Window *win = (*wl)[windex];
	  delete win;
	}
	delete wl;
	fprintf(ener,"%d\t0.000000\t0.000000\t%d\t0\n",startpos,totalLen);
	continue;
      }
      
#ifdef _OPTIMIZE_CACHESEQUENCEPROBABILITIES
      int current_position = (*wl)[0]->Start();
      if (current_position > cache_expires_at) {
	Parameters::DeleteCacheSubsequenceProbabilities(seq);
#ifdef _WTMX_BIAS
	Parameters::CacheSubsequenceProbabilities(seq,&wmc,current_position,windowsize);
	cache_expires_at = current_position + windowsize - 1;
#else 
	Parameters::CacheSubsequenceProbabilities(seq,&wmc,current_position,PROBABILITY_CACHE_SIZE);
	cache_expires_at = current_position + PROBABILITY_CACHE_SIZE - 1;
#endif
      }
#endif
      // train with all matrices
      Parameters_H0 *param0 = new Parameters_H0;
      int bkgIndex = param0->BackgroundIndex(&wmc);
      param0->Initialize(wl,&wmc,bkgIndex);
      param0->SetParameters(fitprobs);
      param0->TrainWithFixedParameters();
      
      param0->Print(ener);
      if (param0->Free_Energy_Differential() > opt->fen_threshold)
	param0->PrintProfile(prof,dict,opt->motif_occurrence_threshold);    
      delete param0;
      
      // clean up
      for (int windex=0; windex<wl->size(); windex++) {
	Window *win =  (*wl)[windex];
	delete win;
      }
      delete wl; 	
    }
    
#ifdef _OPTIMIZE_CACHESEQUENCEPROBABILITIES
    Parameters::DeleteCacheSubsequenceProbabilities(seq);
#endif
  }
    
  fclose(ener);
  fclose(dict);
  fclose(prof);
  delete opt;
}

/*
 * PruneMotifsAndTrain
 *
 * Retrains the H0 model on a reduced set of weight matrices. Only effective
 * when _PRUNE_MOTIFS is defined.
 *
 * Each matrix of `wmc` whose average count under the already-trained `param0`
 * exceeds totalLen * MOTIF_COUNT_THRESHOLD / windowsize is copied into
 * `wmc_prune`. A new Parameters_H0 is then initialized on `wl` with those
 * matrices, seeded with their trained probabilities plus the leftover
 * probability mass for the background, and trained.
 *
 * Parameters:
 *   totalLen     - total length of the windows in `wl`
 *   windowsize   - nominal window size used to scale the count threshold
 *   wl           - window list to retrain on
 *   wmc          - full matrix collection that `param0` was trained with
 *   param0       - trained parameters for `wmc`
 *   wmc_prune    - receives copies of the surviving matrices; must be empty
 *                  on entry (checked via Size())
 *   param0_prune - out: newly allocated retrained parameters, or NULL when no
 *                  matrix passes the threshold or _PRUNE_MOTIFS is undefined
 *
 * Ownership: if param0_prune is set non-NULL, `param0` has been deleted here.
 * If param0_prune is NULL, `param0` is untouched and still owned by the
 * caller.
 *
 * Exits the program on internal inconsistency: wmc_prune not empty on entry,
 * an out-of-range background index, or a negative background probability.
 */
void PruneMotifsAndTrain(int totalLen, int windowsize, vector<Window *> *wl, WtMxCollection *wmc, Parameters_H0 *param0, WtMxCollection *wmc_prune, Parameters_H0 *&param0_prune)
{
#ifdef _PRUNE_MOTIFS
  // pick out the good matrices
#ifndef _OPTIMIZE_WMINDEX
  printf("Error: OPTIMIZE_WMINDEX must be defined\n");
  exit(1);
#endif
  DTYPE count_threshold = totalLen*MOTIF_COUNT_THRESHOLD/float(windowsize);

  int numWM = wmc->Size();
  // Parameter arrays hold one entry per matrix plus one for the background
  // (main() sizes fitprobs as numWM+1 for the same SetParameters call).
  // Reserve that extra slot: with only numWM entries, the background write
  // and SetParameters() overran the buffer whenever every matrix survived.
  DTYPE *trained_pi = new DTYPE[numWM + 1]();
  int numWM_prune = 0;
  for (int wmindex = 0; wmindex < numWM; wmindex++) {
    DTYPE count = param0->ComputeAverageCount(wmindex);
    if (count > count_threshold) {
      wmc_prune->Add(new WtMx(wmc->WM(wmindex)));
      trained_pi[numWM_prune] = param0->GetParameter(wmindex);
      numWM_prune++;
    }
  }
  if (numWM_prune != wmc_prune->Size()) {
    printf("Error: inconsistent number of matrices in wmc_prune\n");
    exit(1);
  }
  if (wmc_prune->Size() == 0) {
    // Nothing survived: leave param0 with the caller.
    param0_prune = NULL;
    delete [] trained_pi;
    return;
  }
  delete param0;

  // now train again and print
  param0_prune = new Parameters_H0;
  int bkgIndex = param0_prune->BackgroundIndex(wmc_prune);
  if (bkgIndex < 0 || bkgIndex > numWM) {
    printf("Error: background index %d out of range\n", bkgIndex);
    exit(1);
  }
  param0_prune->Initialize(wl,wmc_prune,bkgIndex);
  // The background gets whatever probability mass the kept motifs leave.
  DTYPE trained_pi_sum = 0;
  for (int wmindex=0; wmindex < numWM_prune; wmindex++) {
    trained_pi_sum += trained_pi[wmindex];
  }
  DTYPE bkgprob_prune = 1 - trained_pi_sum;
  if (bkgprob_prune < 0) {
    printf("Error: negative background probability after training !\n");
    exit(1);
  }
  trained_pi[bkgIndex] = bkgprob_prune;
  param0_prune->SetParameters(trained_pi);
  param0_prune->Train();
  delete [] trained_pi;
  return;
#else
  param0_prune = NULL;
  return;
#endif
}
