/*****************************************************************
        Copyright by Rockefeller University,
can not be reproduced or distributed without written permission of
copyright holder.  Version of October 2003.

Written by Saurabh Sinha (contact person), Erik van Nimwegen, and 
Eric Siggia.

The program stubb (and its relatives) implement an algorithm for
finding likely cis-regulatory modules, described in the following
paper:
"A Probabilistic Method to Detect Regulatory Modules"
by Saurabh Sinha, Erik van Nimwegen and Eric Siggia. 
Eleventh International Conference on Intelligent Systems for
Molecular Biology, Brisbane, Australia, July 2003, pg 292-301.

The file sample/gap_wtmx that comes with this distribution includes 
a sample set of transcription factor weight matrices (PWM's) that 
were reported in :
"Computational detection of genomic cis-regulatory modules applied
to body patterning in the early Drosophila embryo"
by N. Rajewsky, M. Vergassola, U. Gaul and E. Siggia.
BMC Bioinformatics 3 (30) 2002.
******************************************************************/

#ifndef _parameters_h_
#define _parameters_h_

#include "typedefs.h"
#include "sequence.h"

#include <vector>
using namespace std;

// Markov order reported by Parameters::GetMarkovOrder(); may be overridden
// at build time with -DMARKOV_ORDER=<n>.
#ifndef MARKOV_ORDER
#define MARKOV_ORDER 2
#endif

// Value of Parameters::CONTEXT_WIDTH_FACTOR ("used to create background
// context"); may be overridden at build time with -DCONTEXT_SIZE=<n>.
#ifndef CONTEXT_SIZE
#define CONTEXT_SIZE 10
#endif

// Defined unconditionally, so Parameters::GetBackgroundOrientation() always
// returns this value; its "#else return 0" branch is reachable only if this
// macro is #undef'd before that function is compiled.
#define BKG_FORWARD_ONLY 1
#define DEFAULT_BOTH_ORIENTATIONS true
#define ALMOSTONE 0.9
#define MOTIF_COUNT_THRESHOLD 1.0
#define DEFAULT_MOTIF_OCCURRENCE_THRESHOLD 0.1
#define DEFAULT_FEN_THRESHOLD 10
#define SMALL_MOTIF_OCCURRENCE_THRESHOLD 0.01
#define MS_MOTIF_COUNT_THRESHOLD 2.0
#define PROBABILITY_CACHE_SIZE 25000

// Absolute value of x. Every use of the argument is parenthesized so that
// compound arguments expand correctly: without the parentheses, abs(a-b)
// became (a-b>0?a-b:-a-b), i.e. -a-b whenever a-b is negative.
// Caveats: x is evaluated twice (do not pass expressions with side effects
// such as abs(i++)), and because this is a function-like macro it breaks any
// later std::abs(...) / ::abs(...) call in files that include this header.
#define abs(x) ((x) > 0 ? (x) : -(x))

// Global background weight matrix; this header only declares it (extern), the
// definition lives in some other translation unit.
// NOTE(review): its relationship to Parameters::_bkgwm / SetBackground() is not
// visible in this header -- confirm against the implementation file.
extern WtMx *global_background;

// Abstract base class of the HMM parameter models declared below
// (Parameters_H0, Parameters_H1, Parameters_H01). It holds the state common to
// all of them -- window list, weight-matrix collection, background matrix,
// free-energy bookkeeping and static probability caches -- and declares the
// pure-virtual interface (Initialize / Train / Print* / CreateSequence) that
// every model implements.
class Parameters {
 protected:
  vector<Window*>*_windows;                // window list (presumably the one passed to Initialize()) -- TODO confirm
  WtMxCollection *_wmc;                    // weight-matrix collection
  WtMx           *_bkgwm;                  // background weight matrix (presumably managed by Set/GetBackground())
  int            *_wm_len;                 // for speedup
  int            _numWM;                   // this is needed because _wmc may be modified after
                                           // this object has been used
                                           // but before it has been destroyed
  bool           **_free_emission_probabilities;

  int            _initialbias;
  bool           _is_trained;
  bool           _is_initialized;

  DTYPE          _free_energy;
  DTYPE          _free_energy_differential;// Fb - F
  DTYPE          _free_energy_perlength;
  int            _num_iterations;

  // Holds all the phylogenetic information the parameters object needs to
  // compute probabilities.
  struct Phylogeny {
    float *_mu;
    int   _numSpecies;
    Phylogeny();
    DTYPE  ComputeProbability(WtMx *wm, int index, bool reverse_orientation, char *arrayofindex, int *history_words = NULL);
    // orientation: 0 for fwd, 1 for rev, 2 for both
    DTYPE  ComputeProbabilityGivenAncestor(WtMx *wm, int index,  int orientation, int ancestor, char *arrayofchar, int *history = NULL);
  };

  static struct Phylogeny  *_phy;          // to allow for multiple sequences

  // Non-integral static constants must be constexpr to be initialized inside
  // the class body; "static const float X = ...;" is ill-formed standard C++
  // (it was only ever accepted as an old GCC extension).
  static constexpr float  SMALL_FLOAT     = 1E-10;   // when is a number too small to divide by
  static constexpr double SMALL_DOUBLE    = 1E-200;
  static constexpr float  INF_FREE_ENERGY = 1000000;

  char ReverseChar(char ch);                  // utility function

  // Cache of subsequence probabilities; may be associated with the entire
  // sequence or with a particular window.
  struct probabilitycache {
    DTYPE  **_prob;                           // prob[j][k] is for (j==wm)(k==position in seq)
    int _start;
    int _length;
    WtMx **_wms;
    int _numWM;
    Window *_associatedCurrentWindow;         // if cache is for a particular window, which one is it

    probabilitycache();                                        // regular constructor
    probabilitycache(const probabilitycache &pc);              // copy constructor
    probabilitycache& operator=(const probabilitycache &pc);   // assignment operator
    ~probabilitycache();                                       // destructor
    void Destroy();
    private:
    void Copy(const probabilitycache &pc);
  };

  static constexpr float ALMOST_ONE           = ALMOSTONE;     // used to initialize parameters to an extreme point
  static constexpr float CONTEXT_WIDTH_FACTOR = CONTEXT_SIZE;  // used to create background context

 public:
  static struct probabilitycache *_currentwindowcache; // the vector of caches of subsequence probabilities for associated windows

  Parameters();
  // Virtual because this class is a polymorphic base (the interface below is
  // pure virtual): deleting a Parameters_H0/H1/H01 through a Parameters*
  // without a virtual destructor is undefined behavior.
  virtual ~Parameters() {}

  void   FreeEmissionProbabilities(int wmindex, int offset = -1);
  void   FixEmissionProbabilities(int wmindex, int offset = -1);
  bool   IsEmissionProbabilityFree(int wmindex, int offset = -1);
  float  Free_Energy_Differential();
  float  Free_Energy();
  void   Scale_Free_Energy_Differential(float scalefactor);

  int    InitialBias();
  bool   IsTrained();
  bool   IsInitialized();
  int    NumIterations();
  DTYPE  EvaluateFreeEnergyBackground(bool both_orientations=DEFAULT_BOTH_ORIENTATIONS);
  static WtMx *TrainWtMx(Window *context);
  void   TrainBackground(Window *context=NULL);
  int    BackgroundIndex();
  int    BackgroundIndex(WtMxCollection *wmc);
  void   SetBackground(float *bkg);
  void   GetBackground(float *&bkg);
  void   SetBackground(WtMx  *bkg);
  void   GetBackground(WtMx  *&bkg);
  void   CacheBackgroundProbabilities();
  void   PrintBackground(FILE *fp);
  int    NumWM();
  void   PrintWM(FILE *fp, int i);
  void   PrintPID(FILE *fp);
  DTYPE  PID();

  static int GetMarkovOrder() { return MARKOV_ORDER; }
  // BKG_FORWARD_ONLY is #defined at the top of this header, so the #else
  // branch is only taken if that macro is #undef'd beforehand.
  static int GetBackgroundOrientation() {
#ifdef BKG_FORWARD_ONLY
	return BKG_FORWARD_ONLY;
#else
 	return 0;
#endif
  }
  static float  GetAlmostOne() { return ALMOST_ONE; }
  static float  GetContextWidthFactor() { return CONTEXT_WIDTH_FACTOR; }

  static void SetPhylogeny(float *mu, int numSpecies);
  static const float *GetPhylogeny(int &numSpecies);

  struct probabilitycache *CacheWindowBackgroundProbabilities(Window *win, WtMx *wm);
  static void CacheSubsequenceProbabilities(Sequence *seq, WtMxCollection *wmc, int start, int cache_length, bool lookAtAlignments = true);
  static void DeleteCacheSubsequenceProbabilities(Sequence *seq);

  static DTYPE  ComputeSequenceProbability(Window *win, int start, int stop, WtMx *wm, bool both_orientations=DEFAULT_BOTH_ORIENTATIONS);


  virtual void Initialize(vector<Window *> *wl, WtMxCollection *wmc, Parameters *init) = 0;
  virtual void Initialize(vector<Window *> *wl, WtMxCollection *wmc, int initialbias = -1) = 0;
  virtual void Train(bool differential=true) = 0;
  virtual void Print(FILE *fp, bool verbose = false) = 0;
  virtual void PrintProbabilities(FILE *fp, bool verbose=false) = 0;
  virtual void PrintAverageCounts(FILE *fp, bool verbose=false) = 0;
  virtual char *CreateSequence(int length, int seed, bool verbose=false) = 0;
};

// HMM with no history of last planted motif 
class Parameters_H0 : public Parameters {
  float  *_pi;                       // probability parameters 
  float  *_oldpi;                    // values from previous iteration of optimization

  DTYPE  **_Ai;                      // 
  DTYPE  **_Ail_window;              // 
  DTYPE  *_alpha;                    // forward variables
  DTYPE  *_beta;                     // backward variables
  DTYPE  *_c;                        // scaling factors
  DTYPE  **_cij;                     // Product_i^j {c}
  DTYPE  *_fringe_corrections;       // the fringe correction factors
  int    _max_window_length;         // used for an optimization, set in Initialize()
  Window *_currentWindow;            // as above, used for an optimization

  static const float THRESHOLD = 1e-4;           // used to terminate training
  static const float RELENT_THRESHOLD = 0.00001; // used to decide upon perturbing parameters
  static const int CHECK_ITERATION_THRESHOLD = 20;
  static const int MAX_TRAINING_ITERATIONS = 100;
  static const float CHECK_FEN_THRESHOLD = 0.1;

  DTYPE  EvaluateFreeEnergy();
  DTYPE  FringeCorrectionFactor(int index);  // multiply Ail by this factor to get correct value; index is the window index in the vector
  void   PrepareForUpdate();
  void   UpdateEmissionProbabilities();
  void   UpdateTransitionProbabilities();
  void   Update();                           // update the parameter values
  void   Revert();                           // to go back to previous parameter values
  void   Forward();                          // Forward algorithm (Baum-Welch)
  void   Backward();                         // Backward algorithm (Baum-Welch)
  void   Copy(const Parameters_H0 &p);       // helper member for copy constructor and assignment operator
  void   Destroy();                          // helper member for copy constructor, assignment, and destructor
  DTYPE  Norm_of_parameter_difference();     

 public:
  Parameters_H0();                                     // constructor
  ~Parameters_H0();                                    // destructor
  Parameters_H0(const Parameters_H0 &p);               // copy constructor
  Parameters_H0& operator=(const Parameters_H0 &p);    // assignment operator
  Parameters_H0(WtMxCollection *wmc, int extreme);     // create a dummy set of parameters to begin training from
                                                       // if extreme >= 0, paramters biased to w_i; else uniform

  virtual void Initialize(vector<Window *> *wl, WtMxCollection *wmc, Parameters *init);
  virtual void Initialize(vector<Window *> *wl, WtMxCollection *wmc, int initialbias = -1);
  virtual void DowngradeInitialize(vector<Window *> *wl, WtMxCollection *wmc, class Parameters_H01 *init);
  virtual void Train(bool differential=true);
  void TrainWithFixedParameters(bool differential=true);
  virtual void Print(FILE *fp, bool verbose = false);
  virtual void PrintProbabilities(FILE *fp, bool verbose=false);
  virtual void PrintAverageCounts(FILE *fp, bool verbose=false);
  void PrintProfile(FILE *prof, FILE *dict, float occurrence_threshold = DEFAULT_MOTIF_OCCURRENCE_THRESHOLD);
  virtual char*CreateSequence(int length, int seed, bool verbose=false);

  DTYPE   ComputeAverageCount(int i);
  DTYPE   ComputeExpectedAverageCount(int i,DTYPE  *&expectations);
  DTYPE   ComputeVarianceOfCount(int i, DTYPE  *expectations);

#ifdef _CYCLIC_WINDOWS
  float   **GetLastMotifs();
  void    DeleteSpaceForLastMotifs(float **initial);
#endif

  int     MaximumLeftOverlap(AlignmentNode *al, float occurrence_threshold = DEFAULT_MOTIF_OCCURRENCE_THRESHOLD);
  int     MaximumRightOverlap(AlignmentNode *al, float occurrence_threshold = DEFAULT_MOTIF_OCCURRENCE_THRESHOLD);


  void    SetParameters(DTYPE  *p);
  DTYPE   GetParameter(int index);
  static  float GetThreshold() { return THRESHOLD; }

  friend class Parameters_H1;
  friend class Parameters_H01;
};


// HMM with history of last planted non-background motif
class Parameters_H1 : public Parameters {
  float  **_pij;                    // transition probability parameters 
  float  **_oldpij;                 // values from previous iteration of optimization
#ifdef _CYCLIC_WINDOWS
  float  **_initial;                 // initial probabilities of HMM, for each window
#endif

  DTYPE  ***_Aij;                    // 
  DTYPE  ***_Aijl_window;
  DTYPE  **_alpha;                   // forward variables
  DTYPE  **_beta;                    // backward variables
  DTYPE  *_c;                        // scaling factors
  DTYPE  **_cij;                     // Product_i^j {c}
  DTYPE  *_fringe_corrections;       // the fringe correction factors
  int    _max_window_length;         // used for an optimization, set in Initialize()
  Window *_currentWindow;            // as above, used for an optimization

  static const float THRESHOLD = 1e-4;           // used to terminate training
  static const float ALPHATHRESHOLD = 1e-4;      // used to detect saturation of "alpha" in VarianceComputation
  static const int CHECK_ITERATION_THRESHOLD = 20;
  static const int MAX_TRAINING_ITERATIONS = 100;
  static const float CHECK_FEN_THRESHOLD = 0.1;

  DTYPE          EvaluateFreeEnergy();
  DTYPE          FringeCorrectionFactor(int index);  // multiply Aijl by this factor to get correct value; index is the window index in the vector
  void           PrepareForUpdate();
  void           UpdateEmissionProbabilities();
  virtual void   UpdateTransitionProbabilities();
  virtual void   Update();                           // update the parameter values
  virtual void   Revert();                           // to go back to previous parameter values
  void           Forward();                          // Forward algorithm (Baum-Welch)
  void           Backward();                         // Backward algorithm (Baum-Welch)
  virtual void   Copy(const Parameters_H1 &p);       // helper member for copy constructor and assignment operator
  virtual void   Destroy();                          // helper member for copy constructor, assignment, destructor
  DTYPE          Norm_of_parameter_difference();     
  DTYPE          FringeCorrectionFactorHelper(Window *window);// helper member for initializing fringe_corrections

 public:
  Parameters_H1();                                     // constructor
  ~Parameters_H1();                                    // destructor
  Parameters_H1(const Parameters_H1 &p);               // copy constructor
  Parameters_H1& operator=(const Parameters_H1 &p);    // assignment operator
  Parameters_H1(WtMxCollection *wmc, int extreme);     // create a dummy set of parameters to begin training from
                                                       // if extreme >= 0, paramters biased to w_i; else uniform

  virtual void   Initialize(vector<Window *> *wl, WtMxCollection *wmc, Parameters *init);
  virtual void   Initialize(vector<Window *> *wl, WtMxCollection *wmc, int initialbias = -1);
  virtual void   Train(bool differential=true);
  virtual void   Print(FILE *fp, bool verbose = false);
  virtual void   PrintProbabilities(FILE *fp, bool verbose=false);
  virtual void   PrintAverageCounts(FILE *fp, bool verbose=false);
  void           PrintProfile(FILE *prof, FILE *dict,float occurrence_threshold = DEFAULT_MOTIF_OCCURRENCE_THRESHOLD);
  virtual char   *CreateSequence(int length, int seed, bool verbose=false);

  virtual void   UpgradeInitialize(vector<Window *> *wl, WtMxCollection *wmc, Parameters_H0 *init);
#ifdef _CYCLIC_WINDOWS
  void           SetInitial(float **initial);          
#endif
  void           DoDynamicProgramming();               // Update alpha,beta,Aijl values based on current parameters
  virtual void   ForceCorrelation(int i, int j, float strength);
  DTYPE          ComputeAverageCount(int i, int j);    // how many times is i followed by j, averaged over parses
  DTYPE          ComputeExpectedAverageCount(int i, int j, Parameters_H0* param, DTYPE  *&expectations); 
  DTYPE          ComputeVarianceOfCount(int i, int j, Parameters_H0* param, DTYPE  *expectations);

#ifdef _CYCLIC_WINDOWS
  void           UpdateInitial();
#endif

  static  float  GetThreshold() { return THRESHOLD; }

  friend class Parameters_H0;
  friend class Parameters_H01;
};

// HMM with mixed order (0,1)
class Parameters_H01 : public Parameters_H1 {
  float *_pi;            // the uncorrelated motif probabilities
  float *_oldpi;         // previous values of pi
  bool **_Cij;           // is motif i correlated with motif j (when j follows i)

  void AdjustProbabilityParameters(); // changes the _pij (j not in Ci) by scaling appropirately
                                      // Note: the pij's must always have the property that they sum 
                                      // to 1, and that the uncorr. pijs are scaled versions of uncorr. pj's
                                      // Whenver the correlation structure OR the probability parameters are
                                      // changed, these properties have to be re-enforced.

  virtual void   Copy(const Parameters_H01 &p);      // helper member for copy constructor and assignment operator
  virtual void   Destroy();                          // helper member for copy constructor, assignment, destructor

  DDTYPE  GradF(DDTYPE  *u, int l, int **C_i_bar, int *sizeCibar, DTYPE  **Aij, int N);
  DDTYPE  HessianF(DDTYPE  *u, int l, int m, int **C_i_bar, int *sizeCibar, DTYPE  **Aij, int N);

  static const float  INF_LEAST_SQUARES = 1000000;
  static const int    MAX_NEWTON_ITERATIONS = 100;
  static const DDTYPE  LAMBDA_THRESHOLD = 1e-10;       // at least one singular value must be greater than this
  static const DDTYPE  LEAST_SQUARES_THRESHOLD = 1e-8;

 public:
  Parameters_H01();                                     // constructor
  ~Parameters_H01();                                    // destructor
  Parameters_H01(const Parameters_H01 &p);              // copy constructor
  Parameters_H01& operator=(const Parameters_H01 &p);   // assignment operator

  /* this is the new function in this class */
  void         EnableCorrelation(int *flatlist, int size);
  void         Train(char *corr_file, bool differential = true);

  /* this function hasnt changed, but needs to be declared for this class */
  virtual void Train(bool differential=true);

  /* following functions have changed from the parent class Parameters_H1 */
  virtual void UpdateTransitionProbabilities();
  virtual void Update();                   // update the parameter values
  virtual void Revert();                   // to go back to previous parameter values
  virtual void UpgradeInitialize(vector<Window *> *wl, WtMxCollection *wmc, Parameters_H0 *init);
  virtual void PrintProbabilities(FILE *fp, bool verbose=false);

  /* following are the unsupported functions for this class */
  virtual void ForceCorrelation(int i, int j, float strength);
  virtual char *CreateSequence(int length, int seed, bool verbose=false);

  friend class Parameters_H0;
};

#endif









