// --*- C++ -*------x---------------------------------------------------------
// $Id: alidisc.cc,v 1.2 2008/07/17 11:53:36 bindewae Exp $
//
// Program:         - 
//
// Author:          Eckart Bindewald
//
// Project name:    -
//
// Date:            $Date: 2008/07/17 11:53:36 $
//
// Description:     
// 
// -----------------x-------------------x-------------------x-----------------

#include <iostream>
#include <fstream>
#include <string>
#include <Vec.h>
#include <debug.h>
#include <GetArg.h>
#include <FileName.h>
#include <CompensationScorer.h>
#include <SequenceAlignment.h>
#include <SimpleSequenceAlignment.h>
#include <math.h>
#include <iomanip>
#include <vectornumerics.h>

/**
 * Writes the usage line and the list of supported command line options to os.
 * Option spellings follow the convention already used in the usage line:
 * single-letter options take one dash, long options take two.
 */
void
helpOutput(ostream& os)
{
  os << "usage: alidisc -i fastafile1 fastafile2 [-s targetsequence][-n id][--mover 0|1][--start pos][--stop pos]" << endl;
  os << "other options: " << endl
     << "--alphabet CHARACTERS" << endl
     << "--alphabet-mode dna|dnagap|rna|rnagap : overrides --alphabet" << endl
     << "--commands file : read additional options from command file" << endl
     << "--help          : print this message" << endl
     << "--log file      : append the command line to this log file" << endl
     << "--pos position  : string like -3,5-9,13,16,19-" << endl
     << "--root dir      : directory used for relative command and log file names" << endl
     << "--sim fraction  : remove sequences whose identity to the target is at least this fraction (0 to 1; 0 disables)"
     << endl;
  
}

/**
 * Echoes an argument vector to os: every argument is followed by a single
 * blank and the line is terminated with endl.
 *
 * @param os    stream that receives the echoed arguments
 * @param argc  number of entries in argv
 * @param argv  argument strings, written in order
 */
void
parameterOutput(ostream& os, int argc, char** argv)
{
  int argIndex = 0;
  while (argIndex < argc) {
    os << argv[argIndex] << " ";
    ++argIndex;
  }
  os << endl;
}

/**
 * Likelihood ratio for observing the character seqChar (which may also be a
 * gap character) in alignment column s1 relative to alignment column s2.
 *
 * Each column frequency is smoothed as (count + pseudo) / (size + pseudo);
 * the result is the frequency in s1 divided by the frequency in s2.
 *
 * @param s1        column of the first alignment
 * @param s2        column of the second alignment
 * @param seqChar   character whose preference is evaluated
 * @param alphabet  characters that are scored at all
 * @param pseudo    pseudo count added to numerator and denominator
 * @return the frequency ratio, or 1.0 if seqChar is not part of alphabet
 */
double
alignmentColumnPreference(const string& s1,
                          const string& s2,
                          char seqChar,
                          const string& alphabet,
                          double pseudo)
{
  // characters outside the alphabet get the neutral ratio
  if (alphabet.find(seqChar) == string::npos) {
    return 1.0; // not in alphabet!
  }
  double count1 = 0.0;
  for (string::const_iterator it = s1.begin(); it != s1.end(); ++it) {
    if (*it == seqChar) {
      ++count1;
    }
  }
  double count2 = 0.0;
  for (string::const_iterator it = s2.begin(); it != s2.end(); ++it) {
    if (*it == seqChar) {
      ++count2;
    }
  }
  const double f1 = (count1 + pseudo) / (s1.size() + pseudo);
  const double f2 = (count2 + pseudo) / (s2.size() + pseudo);
  ASSERT(f1 > 0.0);
  ASSERT(f2 > 0.0);
  return f1 / f2;
}

/*
double
alignmentColumnPreference(const string& s1a,
			  const string& s1b,
			  const string& s2a,
			  const string& s2b,
			  char seqChar1,
			  char seqChar2)
{
  // cout << "Starting alignmentColumnPreference!" << endl;
  //  double f1 = CompensationScorer::frequency3(s1, seqChar, 1, alphabet.size());
  // double f2 = CompensationScorer::frequency3(s2, seqChar, 1, alphabet.size());
  unsigned int sum1 = 1; // use pseudocount
  unsigned int sum2 = 1; // use pseudocount
  for (unsigned int i = 0; i < s1a.size(); ++i) { 
    if ((s1a[i] == seqChar1) && (s1b[i] == seqChar2)) {
      ++sum1;
    }
    if ((s2a[i] == seqChar1) && (s2b[i] == seqChar2)) {
      ++sum2;
    }
  }
  double result = (static_cast<double>(sum1)/s1a.size()) 
    / (static_cast<double>(sum2)/s2a.size());
  // cout << "Finished alignmentColumnPreference!" << endl;
  return result;
}
*/

/**
 * Builds a table of log10 preference ratios between two alignments of equal
 * length. Entry [j][i] holds log10 of alignmentColumnPreference for alphabet
 * character j at column i, comparing column i of ali1 with column i of ali2.
 *
 * @param ali1      first alignment
 * @param ali2      second alignment, must have the same length as ali1
 * @param alphabet  characters for which a row is produced
 * @param pseudo    pseudo count passed on to alignmentColumnPreference
 * @return table with alphabet.size() rows and ali1.getLength() columns
 */
Vec<Vec<double> >
compareAlignments(const SequenceAlignment& ali1,
                  const SequenceAlignment& ali2,
                  const string& alphabet,
                  double pseudo)
{
  PRECOND(ali1.getLength() == ali2.getLength());
  Vec<Vec<double> > logRatios(alphabet.size(), Vec<double>(ali1.getLength(), 0.0));
  const unsigned int numColumns = ali1.getLength();
  for (unsigned int col = 0; col < numColumns; ++col) {
    // fetch both columns once, they are reused for every alphabet character
    const string column1 = ali1.getColumn(col);
    const string column2 = ali2.getColumn(col);
    for (unsigned int sym = 0; sym < alphabet.size(); ++sym) {
      const double ratio = alignmentColumnPreference(column1, column2,
                                                     alphabet[sym], alphabet,
                                                     pseudo);
      logRatios[sym][col] = log10(ratio);
    }
  }
  return logRatios;
}


/*
Vec<Vec<double> >
compareAlignmentsDuos(const SequenceAlignment& ali1,
		  const SequenceAlignment& ali2,
		  const string& alphabet)
{
  PRECOND(ali1.getLength() == ali2.getLength());
  // PRECOND(seq.size() == ali1.getLength());
  // cout << "Starting compareAlignments!" << endl;
  Vec<Vec<double> > avgRow(alphabet.size()*alphabet.size(), Vec<double>(ali1.getLength(), 0.0));
  for (unsigned int i = 1; i < ali1.getLength(); ++i) {
    string s1a = ali1.getColumn(i-1);
    string s1b = ali1.getColumn(i);
    string s2a = ali2.getColumn(i-1);
    string s2b = ali2.getColumn(i);
    for (unsigned int j = 0; j < alphabet.size(); ++j) {
      for (unsigned int k = 0; k < alphabet.size(); ++k) {
	double ltr = alignmentColumnPreference(s1a, s1b, s2a, s2b, alphabet[j], alphabet[k]);
	avgRow[alphabet.size()*j+k][i] = log10(ltr);
      }
    }
  }
  // cout << "Finished compareAlignments!" << endl;
  return avgRow;
}
*/

/**
 * Scores a sequence by its single-character preferences over a contiguous
 * range of columns: the sum of log10(alignmentColumnPreference) for the
 * character seq[i] at every column i from startPos to stopPos (inclusive,
 * zero-based).
 *
 * @param seq       sequence to score, indexed with the same column numbers
 * @param ali1      first alignment
 * @param ali2      second alignment
 * @param alphabet  characters that are scored
 * @param startPos  first column (zero-based)
 * @param stopPos   last column (zero-based, inclusive)
 * @param pseudo    pseudo count passed on to alignmentColumnPreference
 * @return summed log10 ratios over the range
 */
double
scoreSequence(const string& seq,
              const SequenceAlignment& ali1,
              const SequenceAlignment& ali2,
              const string& alphabet,
              unsigned int startPos,
              unsigned int stopPos,
              double pseudo) 
{
  double total = 0.0;
  unsigned int col = startPos;
  while (col <= stopPos) {
    const double ratio = alignmentColumnPreference(ali1.getColumn(col),
                                                   ali2.getColumn(col),
                                                   seq[col], alphabet,
                                                   pseudo);
    total += log10(ratio);
    ++col;
  }
  return total;
}

/** scores sequence according to single-nucleotide preferences */
double
scoreSequence(const string& seq,
	      const SequenceAlignment& ali1,
	      const SequenceAlignment& ali2,
	      const string& alphabet,
	      const Vec<unsigned int>& pos,
	      double pseudo)
{
//   cout << "Starting scoreSequence(2): " << alphabet << " " << pos << " "
//        << seq;
  double score = 0.0;
  for (unsigned int i = 0; i < pos.size(); ++i) {
    score += log10(alignmentColumnPreference(ali1.getColumn(pos[i]),
					     ali2.getColumn(pos[i]),
					     seq[pos[i]], alphabet,
					     pseudo));
  }
  // cout << " result: " << score;
  return score;
}


/**
 * Removes every sequence from ali whose similarity to targetSeq reaches
 * seqSimLimit. Similarity is the number of identical positions reported by
 * SequenceAlignment::countIdentical (called with "XN" as third argument)
 * divided by the length of targetSeq.
 *
 * @param targetSeq    reference sequence
 * @param ali          alignment that is filtered in place
 * @param seqSimLimit  sequences with similarity >= this value are removed
 */
void
removeTooSimilar(const string& targetSeq, 
                 SequenceAlignment& ali, 
                 double seqSimLimit)
{
  // walk from the last sequence to the first so that removing an entry
  // never shifts an index that still has to be visited
  unsigned int remaining = static_cast<unsigned int>(ali.size());
  while (remaining > 0) {
    --remaining;
    const unsigned int numIdent =
      SequenceAlignment::countIdentical(targetSeq,
                                        ali.getSequence(remaining), "XN");
    const double seqSim = static_cast<double>(numIdent) / targetSeq.size();
    if (seqSim >= seqSimLimit) {
      ali.removeSequence(remaining);
    }
  }
}

/**
 * Program entry point.
 *
 * Reads two FASTA alignments of equal length (-i), optionally a target
 * sequence (either from a separate FASTA file via -s or taken out of one of
 * the alignments via -n / --mover), prints for every (selected) column the
 * log10 preference ratio of each alphabet character and, if a target
 * sequence is available, the target character with its score for that
 * column. Finally the summed score of the target sequence is printed.
 *
 * @param argc  number of command line arguments
 * @param argv  command line arguments
 * @return 0 on normal termination
 */
int
main(int argc, char ** argv)
{

  // initialized explicitly: the value must be defined even if getArg
  // leaves it untouched when the flag is absent
  bool helpMode = false;
  int argcFile = 0;
  int moverMode = 1;
  int seqId = -1;
  char ** argvFile = 0;
  double pseudo = 1.0; // Bayesian pseudo count
  int startPos = 1;
  int stopPos = 0;
  unsigned int verboseLevel = 1;
  double seqSimLimit = 0.5;
  string alphabet = "ACGT";
  string alphabetMode;
  string commandFileName;
  string logFileName; //  = "mainprogramtemplate.log";
  string posVectorString;
  string rootDir = ".";
  string targetSeq;
  string targetName;
  Vec<unsigned int> posVec;
  Vec<string> inputFileNames;
  string sequenceFileName;
  SimpleSequenceAlignment ali1, ali2, sequenceAli;

  getArg("-help", helpMode, argc, argv);

  if ((argc < 2) || helpMode)  {
    helpOutput(cout);
    return 0;
  }

  getArg("-root", rootDir, argc, argv, rootDir);
  addSlash(rootDir);

  getArg("-commands", commandFileName, argc, argv, commandFileName);
  addPathIfRelative(commandFileName, rootDir);

  if (commandFileName.size() > 0) {
    ifstream commandFile(commandFileName.c_str());
    if (!commandFile) {
      if (isPresent("-commands", argc, argv)) {
	ERROR_IF(!commandFile, "Error opening command file.");
      }
      else {
	cerr << "Warning: Could not find command file: " + commandFileName 
	     << endl;
      }
    }
    else {
      argvFile = streamToCommands(commandFile, argcFile, 
				  string("mainprogramtemplate"));
    }
    commandFile.close();
  }
  
  // options from the command file are read first so that the command line
  // can override them
  getArg("-alphabet", alphabet, argcFile, argvFile, alphabet);
  getArg("-alphabet", alphabet, argc, argv, alphabet);
  getArg("-alphabet-mode", alphabetMode, argcFile, argvFile, alphabetMode);
  getArg("-alphabet-mode", alphabetMode, argc, argv, alphabetMode);
  getArg("i", inputFileNames, argc, argv);
  getArg("-log", logFileName, argc, argv, logFileName);
  getArg("-log", logFileName, argcFile, argvFile, logFileName);
  addPathIfRelative(logFileName, rootDir);
  getArg("-mover", moverMode, argc, argv, moverMode); // specifies if sequence to be predicted comes from mover alignment
  getArg("n", seqId, argc, argv, seqId);
  --seqId; // user counts from one, internally from zero
  getArg("-pos", posVectorString, argc, argv);
  posVec = parseStringToVector(posVectorString);
  convert2InternalCounting(posVec);
  getArg("s", sequenceFileName, argc, argv);
  getArg("-start", startPos, argc, argv, startPos);
  getArg("-stop", stopPos, argc, argv, stopPos);
  --startPos; // user counts from one, internally from zero
  --stopPos;
  getArg("-sim", seqSimLimit, argcFile,argvFile, seqSimLimit);
  getArg("-sim", seqSimLimit, argc,argv, seqSimLimit);
  ERROR_IF((seqSimLimit > 1.0) || (seqSimLimit < 0.0),
	   "Sequence similarity limit has to be fraction between 0 and 1!" );
  getArg("-verbose", verboseLevel, argcFile, argvFile, verboseLevel);
  getArg("-verbose", verboseLevel, argc, argv, verboseLevel);


  if (logFileName.size() > 0) {
    ofstream logFile(logFileName.c_str(), ios::app);
    if (!logFile) {
      // logging is best effort: report the problem but keep running
      cerr << "Warning: Could not open log file: " + logFileName << endl;
    }
    else {
      parameterOutput(logFile, argc, argv);
      if (argcFile > 1) {
	logFile << "Parameters from command file: ";
	parameterOutput(logFile, argcFile, argvFile);
      }
    }
    logFile.close();
  }


  /***************** MAIN PROGRAM *****************************/

  ERROR_IF(inputFileNames.size() != 2,
	   "2 alignment names expected!");
  ifstream inputFile1(inputFileNames[0].c_str());
  ERROR_IF(!inputFile1, "Error opening input file 1!" );
  ali1.readFasta(inputFile1);
  inputFile1.close();
  ali1.upperCaseSequences();
  cout << "SequenceAlignment 1: " << ali1.size() << " sequences with length " << ali1.getLength() << endl;

  ifstream inputFile2(inputFileNames[1].c_str());
  ERROR_IF(!inputFile2, "Error opening input file 2!");
  ali2.readFasta(inputFile2);
  inputFile2.close();
  ali2.upperCaseSequences();
  cout << "SequenceAlignment 2: " << ali2.size() << " sequences with length " << ali2.getLength() << endl;

  if (sequenceFileName.size() > 0) {
    ifstream sequenceFile(sequenceFileName.c_str());
    ERROR_IF(!sequenceFile, "Error opening sequence file!");
    sequenceAli.readFasta(sequenceFile);
    sequenceFile.close();
    ERROR_IF(sequenceAli.size() < 1, 
	     "No sequence defined in alignment!");
    targetSeq = sequenceAli.getSequence(0);
    targetName = sequenceAli.getName(0);
  }

  ERROR_IF(ali1.getLength() != ali2.getLength(),
	   "Alignments have to have the same length!");	  

  // sets alphabet to DNA, RNA with or without gap.
  // The tests must form a single else-if chain: otherwise the final else
  // would reject every mode except the last one.
  if (alphabetMode.size() > 0) {
    if (alphabetMode.compare("dna") == 0) {
      alphabet = "ACGT";
    }
    else if (alphabetMode.compare("dnagap") == 0) {
      alphabet = "ACGT-";
    }
    else if (alphabetMode.compare("rna") == 0) {
      alphabet = "ACGU";
    }
    else if (alphabetMode.compare("rnagap") == 0) {
      alphabet = "ACGU-";
    }
    else {
      ERROR("Unrecognized alphabet mode!");
    }
  }
  cout << "# Using alphabet: " << alphabet << endl;

  // take the target sequence out of one of the alignments if requested
  if (seqId >= 0) {
    if (moverMode) {
      ERROR_IF(static_cast<unsigned int>(seqId) >= ali1.size(),
	       "Sequence id exceeds number of sequences in alignment 1!");
      targetSeq = ali1.getSequence(seqId);
      targetName = ali1.getName(seqId);
      ali1.removeSequence(seqId);
    }
    else {
      ERROR_IF(static_cast<unsigned int>(seqId) >= ali2.size(),
	       "Sequence id exceeds number of sequences in alignment 2!");
      targetSeq = ali2.getSequence(seqId);
      targetName = ali2.getName(seqId);
      ali2.removeSequence(seqId);
    }
  }

  if ((targetSeq.size() > 0) && (seqSimLimit > 0.0)) {
    // delete all sequences that are too similar to the target sequence
    cout << "Removing too similar sequences!" << endl;
    removeTooSimilar(targetSeq, ali1, seqSimLimit);
    removeTooSimilar(targetSeq, ali2, seqSimLimit);
    cout << "Number of sequences after removing: " << ali1.size() << " " << ali2.size() << endl;
  }


  Vec<Vec<double> > avgRows = compareAlignments(ali1, ali2, alphabet, pseudo);

  // compareAlignments produces one entry per column of ali1 in every row;
  // using the alignment length avoids indexing avgRows[0] for an empty alphabet
  const unsigned int numColumns = ali1.getLength();
  for (unsigned int j = 0; j < numColumns; ++j) {
    if (posVec.size() > 0) {
      // check if part of position:
      if (findFirstIndex(posVec, j) >= posVec.size()) {
	continue; // skip this position!
      }
    }
    cout << (j+1) << " ";
    for (unsigned int i = 0; i < avgRows.size(); ++i) {
      cout << avgRows[i][j] << " ";
    }
    // the target columns are only printed where a target character exists;
    // without a target sequence (or beyond its end) there is nothing to index
    if (j < targetSeq.size()) {
      cout << targetSeq[j] << " "; // print sequence character of potential target sequence
      cout << log10(alignmentColumnPreference(ali1.getColumn(j),
					      ali2.getColumn(j),
					      targetSeq[j], alphabet,
					      pseudo)); // score of this position
    }
    cout << endl;
  }

  /*
  // Vec<Vec<double> > avgRows2 = compareAlignmentsDuos(ali1, ali2, alphabet);
    for (unsigned int j = 0; j < alphabet.size(); ++j) {
    for (unsigned int k = 0; k < alphabet.size(); ++k) {
    cout << alphabet[j] << alphabet[k] << " ";
    }
    }
    
    for (unsigned int j = 0; j < avgRows2[0].size(); ++j) {
    cout << (j+1) << " ";
    double maxi = avgRows[0][j];
    // for (unsigned int i = 0; i < avgRows2.size(); ++i) {
    for (unsigned int k1 = 0; k1 < alphabet.size(); ++k1) {
    for (unsigned int k2 = 0; k2 < alphabet.size(); ++k2) {
    unsigned int i = k1 * alphabet.size() + k2;
    cout << alphabet[k1] << alphabet[k2] << ":" 
    << setw(5) << setprecision(2) << avgRows2[i][j] << " ";
    if (fabs(avgRows2[i][j]) > maxi) {
    maxi = fabs(avgRows2[i][j]);
    }
    }
    }
    cout << " max: " << maxi << endl;
    }
  */

  if (targetSeq.size() > 0) {
    // exclusive upper bound for positions that exist in the target sequence
    // as well as in the alignments; scoring indexes both with the same value
    unsigned int posLimit = targetSeq.size();
    if (ali1.getLength() < posLimit) {
      posLimit = ali1.getLength();
    }
    double score = 0;
    if (posVec.size() == 0) {
      // a start position of zero or less would wrap around when converted
      // to unsigned below
      ERROR_IF(startPos < 0, "Start position has to be one or greater!");
      if (stopPos < startPos) {
	// no usable stop position given: score up to the end of the target
	stopPos = static_cast<int>(targetSeq.size())-1;
      }
      ERROR_IF(static_cast<unsigned int>(stopPos) >= posLimit,
	       "Stop position exceeds length of target sequence or alignments!");
      score = scoreSequence(targetSeq, ali1, ali2, alphabet,
			    static_cast<unsigned int>(startPos), 
			    static_cast<unsigned int>(stopPos),
			    pseudo);
    }
    else {
      for (unsigned int i = 0; i < posVec.size(); ++i) {
	ERROR_IF(posVec[i] >= posLimit,
		 "Position exceeds length of target sequence or alignments!");
      }
      score = scoreSequence(targetSeq, ali1, ali2, alphabet, posVec, pseudo);
    }
    cout << "Score of ";
    if (moverMode) {
      cout << "mover";
    }
    else {
      cout << "non-mover";
    }
    cout << " sequence " << (seqId+1) << " " << score << " " << targetName << endl;
  }

  return 0;
}
