#include <sequencestat.h>
#include <math.h>
#include <Limits.h>

/** Builds the continuous-space point of one residue code: component r of
    the result is sub[r][aik], i.e. entry aik taken from every row of sub.
    @param aik residue code, used as index into each row of sub
    @param sub substitution matrix
    @return vector with sub.size() components
    @see A new Objective Function for Multiple Alignment, JMB 2001
*/
static
Vec<double>
generateS(int aik, const Vec<Vec<int> >& sub)
{
  Vec<double> point(sub.size());
  unsigned int row = 0;
  while (row < sub.size()) {
    // aik must be a valid column of this row (upper bound only is checked)
    ASSERT(aik < static_cast<int>(sub[row].size()));
    point[row] = sub[row][aik];
    ++row;
  }
  return point;
}

/** Returns the squared Euclidean distance between the continuous-space
    points (see generateS) of two residue codes. No square root is taken.
    @param aik first residue code
    @param ajk second residue code
    @param sub substitution matrix handed on to generateS
    @return sum over all components of (s1[i] - s2[i])^2
*/
static
double
distij(int aik, int ajk, const Vec<Vec<int> >& sub)
{
  const Vec<double> s1 = generateS(aik, sub);
  const Vec<double> s2 = generateS(ajk, sub);
  double result = 0.0;
  for (unsigned int i = 0; i < s1.size(); ++i) {
    // component i of both points: every dimension has to contribute,
    // not only the first one
    const double term = s1[i] - s2[i];
    result += term * term;
  }
  return result;
}

/** Returns the precomputed distance between two residue codes by looking
    up sub2[aik][ajk].
    @param aik first residue code, row index into sub2
    @param ajk second residue code, column index into row aik of sub2
    @param sub2 matrix of residue distances
    @return sub2[aik][ajk]
*/
static
double
distijFast(int aik, int ajk, const Vec<Vec<double> >& sub2)
{
  // row index must be non-negative (negative codes denote gaps) and in range
  PRECOND((aik >= 0) && (aik < static_cast<int>(sub2.size())));
  // column index is checked against the row it actually indexes
  PRECOND((ajk >= 0) && (ajk < static_cast<int>(sub2[aik].size())));
  return sub2[aik][ajk];
}

/** Counts the positions at which both vectors hold the same non-negative
    value. Negative values denote gaps and never count as identical.
    @param v1 first encoded sequence
    @param v2 second encoded sequence, expected to have the size of v1
    @return number of positions with v1[i] == v2[i] and v1[i] >= 0
*/
int
countNumIdentical(const Vec<int>& v1, const Vec<int>& v2)
{
  PRECOND(v1.size() == v2.size());
  int matches = 0;
  for (unsigned int pos = 0; pos < v1.size(); ++pos) {
    if (v1[pos] != v2[pos]) {
      continue;
    }
    if (v1[pos] < 0) {
      continue; // identical gaps are not counted
    }
    ++matches;
  }
  return matches;
}

double
computeQ(const Vec<Vec<int> >& ali, unsigned int col,
	 const Vec<Vec<int> >& sub)
{
  PRECOND(ali.size() > 0);
  double enume  = 0.0;
  double denom = 0.0;
  double w, dij;
  double length = ali[0].size();
  for (unsigned int i = 0; i < ali.size(); ++i) {
    for (unsigned int j = i+1; j < ali.size(); ++j) {
      if ((ali[i][col] < 0) || (ali[j][col] < 0)) {
	continue; // skip gaps
      }
      w = 1.0 - (countNumIdentical(ali[i], ali[j]) / length);
      dij = distij(ali[i][col], ali[j][col], sub);
      enume += w * dij;
      denom += w;
    }
  }
  double result = enume;
  if (denom > 0.0) {
    result /= denom;
  }
  return result;
}

/** Same column score as computeQ, but working on precomputed tables.
    @param ali encoded alignment, must not be empty
    @param col column index
    @param sub2 not the normal substitution matrix, but a matrix of residue
           distances indexed by the two residue codes
    @param countIdent element [i][j] holds the number of identical residues
           between sequences i and j of the alignment
    @return weighted sum of distances divided by the sum of weights; if the
            sum of weights is not positive the undivided sum is returned
 */
double
computeQFast(const Vec<Vec<int> >& ali, unsigned int col,
             const Vec<Vec<double> >& sub2,
             const Vec<Vec<int> >& countIdent )
{
  PRECOND(ali.size() > 0);
  const double aliLength = ali[0].size();
  double weightedDist = 0.0;
  double weightSum = 0.0;
  for (unsigned int a = 0; a < ali.size(); ++a) {
    for (unsigned int b = a + 1; b < ali.size(); ++b) {
      const int resA = ali[a][col];
      if (resA < 0) {
        continue; // gap in first sequence of the pair
      }
      const int resB = ali[b][col];
      if (resB < 0) {
        continue; // gap in second sequence of the pair
      }
      const double weight = 1.0 - (countIdent[a][b] / aliLength);
      const double dist = distijFast(resA, resB, sub2);
      weightedDist += weight * dist;
      weightSum += weight;
    }
  }
  if (weightSum > 0.0) {
    return weightedDist / weightSum;
  }
  return weightedDist;
}


/** Generates the matrix of residue distances: element [i][j] is
    distij(i, j, sub). The result has sub.size() rows and sub[0].size()
    columns.
    @param sub substitution matrix
    @return distance matrix; an empty matrix if sub is empty
    @see : Stochastic pairwise matrices.
 */
Vec<Vec<double> >
generateSpaceMatrix(const Vec<Vec<int> >& sub)
{
  if (sub.size() == 0) {
    // nothing to tabulate; sub[0] must not be touched
    return Vec<Vec<double> >();
  }
  Vec<Vec<double> > sub2(sub.size(), Vec<double>(sub[0].size()));
  for (unsigned int i = 0; i < sub2.size(); ++i) {
    for (unsigned int j = 0; j < sub2[i].size(); ++j) {
      sub2[i][j] = distij(i, j, sub);
    }
  }
  return sub2;
}

/** Builds the symmetric table of pairwise identity counts of an alignment.
    @param ali encoded alignment
    @return square matrix of size ali.size(); element [i][j] (and [j][i]) is
            countNumIdentical(ali[i], ali[j]), the diagonal included
*/
Vec<Vec<int> >
generateIdentityMatrix(const Vec<Vec<int> >& ali)
{
  Vec<Vec<int> > table(ali.size(), Vec<int>(ali.size(), 0));
  for (unsigned int row = 0; row < ali.size(); ++row) {
    // lower triangle plus diagonal; each value is mirrored to the upper triangle
    for (unsigned int other = 0; other <= row; ++other) {
      const int shared = countNumIdentical(ali[row], ali[other]);
      table[row][other] = shared;
      table[other][row] = shared;
    }
  }
  return table;
}

/** Counts the sequences that have a non-gap (non-negative) code in column col.
    @param ali encoded alignment
    @param col column index, must be valid for every sequence
    @return number of sequences with ali[i][col] >= 0
*/
unsigned int
countNoGaps(const Vec<Vec<int> >& ali, unsigned int col)
{
  unsigned int occupied = 0;
  unsigned int row = 0;
  while (row < ali.size()) {
    ASSERT(col < ali[row].size());
    if (!(ali[row][col] < 0)) {
      ++occupied;
    }
    ++row;
  }
  return occupied;
}

double
computeMD(const Vec<Vec<int> >& ali,
	  const Vec<Vec<int> >& sub)
{
  if (ali.size() == 0) {
    return 0.0;
  }
  Vec<Vec<double> > sub2 = generateSpaceMatrix(sub);
  Vec<Vec<int> > numIdentMatrix = generateIdentityMatrix(ali);
  unsigned int len = ali[0].size();
  double result = 0.0;
  double term = 0.0;
  double fracNoGaps = 0.0;
  for (unsigned int i = 0; i < len; ++i) {
    term = exp(-computeQFast(ali, i, sub2, numIdentMatrix));
    // normalize term: multiply by percentage of sequences, that have no gaps at position:
    fracNoGaps = countNoGaps(ali, i) / static_cast<double>(ali.size());
    result += (fracNoGaps * term);
  }
  ERROR_IF(!isReasonable(result), "internal error in line 180!");
  // cout << "Result of MD function: " << result << endl;
  return result;
}


double
computeMDFast(const Vec<Vec<int> >& ali,
	      const Vec<Vec<double> >& sub2,
	      const Vec<Vec<int> >& identMatrix)
{
  if (ali.size() == 0) {
    return 0.0;
  }
  unsigned int len = ali[0].size();
  double result = 0.0;
  double term = 0.0;
  double fracNoGaps = 0.0;
  for (unsigned int i = 0; i < len; ++i) {
    term = exp(-computeQFast(ali, i, sub2, identMatrix));
    // normalize term: multiply by percentage of sequences, that have no gaps at position:
    fracNoGaps = countNoGaps(ali, i) / static_cast<double>(ali.size());
    result += (fracNoGaps * term);
  }
  ERROR_IF(!isfinite(result), "internal error in line 180!");
  // cout << "Result of MD fast function: " << result << endl;
  return result;
}

/** Returns the number of non-gap (non-negative) entries of an encoded sequence.
    @param s encoded sequence
    @return count of elements with s[i] >= 0
*/
unsigned int 
sequenceLength(const Vec<int>& s)
{
  unsigned int residues = 0;
  unsigned int pos = 0;
  while (pos < s.size()) {
    if (!(s[pos] < 0)) {
      ++residues;
    }
    ++pos;
  }
  return residues;
}

/** Returns the ungapped length (see sequenceLength) of every sequence.
    @param ali encoded alignment
    @return vector with one length per sequence, in alignment order
*/
Vec<unsigned int>
sequenceLengths(const Vec<Vec<int> >& ali)
{
  Vec<unsigned int> lengths(ali.size());
  unsigned int row = 0;
  while (row < ali.size()) {
    lengths[row] = sequenceLength(ali[row]);
    ++row;
  }
  return lengths;
}


double
computeMaxMD(const Vec<Vec<int> >& ali)
{
  ERROR_IF(ali.size() == 0,
	   "Undefined alignment in computeMaxMd!");
//   if (ali.size() == 0) {
//     return 0.0;
//   }
  // find longest sequence:
  // unsigned int longestId = 0;

  // unsigned int len = 0;
  Vec<unsigned int> lengths = sequenceLengths(ali);
  sort(lengths.begin(), lengths.end());
  unsigned int longestLength = lengths[lengths.size()-1];
  unsigned int counter = 0;
  double sum = 0.0;
  // cout << "Lengths in MaxMd: " << lengths << endl;
  for (unsigned int i = 1; i <= longestLength; ++i) {
    for ( ; counter < lengths.size(); ++counter) {
      if (lengths[counter] > i) {
	// 	cout << "Number of sequence smaller than " << i << " " << lengths[counter] << " "
	// 	     << counter << endl;
	break;
      }
    }
    sum += counter;
    // cout << "Result of sum: " << sum << endl;
  }
  sum /= ali.size();
  return sum;
}


/** Lists the gap runs (maximal stretches of negative codes) of a sequence.
    @param s encoded sequence
    @return one pair per gap run: first is the running count kept while
            scanning up to the start of the run (residues seen so far, plus
            one per earlier gap run for the position that ended it), second
            is the length of the run
*/
Vec<pair<unsigned int, unsigned int> >
findGaps(const Vec<int>& s)
{
  Vec<pair<unsigned int, unsigned int> > runs;
  unsigned int seen = 0;
  unsigned int pos = 0;
  while (pos < s.size()) {
    if (s[pos] >= 0) {
      ++seen; // how many non-gap characters left from position
      ++pos;
      continue;
    }
    // extend the run until a non-gap code or the end of the sequence
    unsigned int stop = pos + 1;
    while ((stop < s.size()) && (s[stop] < 0)) {
      ++stop;
    }
    runs.push_back(pair<unsigned int, unsigned int>(seen, stop - pos));
    // the position that ended the run is counted here and then skipped
    ++seen;
    pos = stop + 1;
  }
  return runs;
}

/** Removes every column in which both sequences have a gap (negative code).
    Both vectors are modified in place and keep equal sizes.
    @param s1 first encoded sequence
    @param s2 second encoded sequence, expected to have the size of s1
*/
void
purgeGaps(Vec<int>& s1, Vec<int>& s2)
{
  PRECOND(s1.size() == s2.size());
  // walk from the back so that erasing does not shift unvisited columns
  int col = static_cast<int>(s1.size());
  while (col-- > 0) {
    const bool gapInBoth = (s1[col] < 0) && (s2[col] < 0);
    if (gapInBoth) {
      s1.erase(s1.begin() + col);
      s2.erase(s2.begin() + col);
    }
  }
}


double
computeGapCost(const Vec<int>& s1Orig,
	       const Vec<int>& s2Orig,
	       double gOpen, double gExt)
{
  Vec<int> s1 = s1Orig;
  Vec<int> s2 = s2Orig;
  purgeGaps(s1, s2);
  Vec<pair<unsigned int, unsigned int> > gaps1 = findGaps(s1);
  Vec<pair<unsigned int, unsigned int> > gaps2 = findGaps(s2);
  unsigned int totLen = 0;
  for (unsigned int i = 0; i < gaps1.size(); ++i) {
    totLen += gaps1[i].second;
  }
  for (unsigned int i = 0; i < gaps2.size(); ++i) {
    totLen += gaps2[i].second;
  }
  unsigned int totNum = gaps1.size() + gaps2.size();
  return ((gOpen * totNum) + (gExt * totLen));
}

/** Computes the mean pairwise gap cost of an alignment: the pairwise
    computeGapCost summed over all sequence pairs i < j and divided by the
    number of pairs n*(n-1)/2.
    @param ali encoded alignment
    @param gOpen cost per gap run
    @param gExt cost per gap position
    @return mean pairwise gap cost; 0.0 if there are fewer than two
            sequences (no pair exists)
*/
double
computeGapCost(const Vec<Vec<int> >& ali,
               double gOpen, double gExt)
{
  const unsigned int n = ali.size();
  if (n < 2) {
    // without a pair the quotient below would be 0/0
    return 0.0;
  }
  double sum = 0.0;
  for (unsigned int i = 0; i < n; ++i) {
    for (unsigned int j = i+1; j < n; ++j) {
      sum += computeGapCost(ali[i], ali[j], gOpen, gExt);
    }
  }
  // form the product in double so that it cannot wrap for large n
  return (2 * sum / (static_cast<double>(n) * (n - 1)));
}
	     

/** Translates one alignment string into a vector of int codes.
    The first letter of the alphabet is translated into a zero,
    the second letter into a one and so forth. Every character that is not
    part of the alphabet (e.g. a gap character) is translated into -1.
    Example: ACCG with alphabet ACGT gets translated into 0 1 1 2
    @param ali alignment string
    @param alphabet allowed residue letters, position defines the code
    @return vector with one code per character of ali
*/
Vec<int>
translateAliStringToInt(const string& ali, 
                        const string& alphabet)
{
  Vec<int> result(ali.size());
  for (unsigned int i = 0; i < ali.size(); ++i) {
    const string::size_type pos = alphabet.find(ali[i]);
    // map "not found" to -1 explicitly instead of relying on how npos
    // converts to int
    result[i] = (pos == string::npos) ? -1 : static_cast<int>(pos);
  }
  return result;
}


/** Translates a vector of alignment strings into a vector of int vectors by
    applying the single-string translateAliStringToInt to every element.
    First letter in alphabet gets translated into a zero,
    second letter into a one and so forth.
    Example: ACCG with alphabet ACGT gets translated into 0 1 1 2
    @param ali alignment strings
    @param alphabet allowed residue letters, position defines the code
    @return one encoded sequence per input string, in the same order
*/
Vec<Vec<int> >
translateAliStringToInt(const Vec<string>& ali, 
                        const string& alphabet)
{
  Vec<Vec<int> > encoded(ali.size());
  unsigned int row = 0;
  while (row < ali.size()) {
    encoded[row] = translateAliStringToInt(ali[row], alphabet);
    ++row;
  }
  return encoded;
}

/** see paper by Thompson 2001 JMB. Still not implemented factor "LQRID" of paper */
double
computeNormMD(const Vec<Vec<int> >& ali,
	      const Vec<Vec<int> >& sub,
	      double gOpen, double gExt)
{
  if (ali.size() < 2) {
    return 0.0;
  }
  double maxMd = computeMaxMD(ali);
  double gapCost = computeGapCost(ali, gOpen, gExt);
  double normMd = (computeMD(ali, sub) - gapCost) / maxMd;
  // cout << "Result of NormMD function: " << normMd << endl;
  return normMd;
}


/** see paper by Thompson 2001 JMB. Still not implemented factor "LQRID" of paper */
double
computeNormMD(const Vec<string>& aliOrig,
	      const Vec<Vec<int> >& sub,
	      const string& alphabet,
	      double gOpen, double gExt)
{
  if (aliOrig.size() < 2) {
    return 0.0;
  }
  Vec<Vec<int> > ali = translateAliStringToInt(aliOrig, alphabet);
  double maxMd = computeMaxMD(ali);
  ERROR_IF(maxMd <= 0.0, "MaxMD smaller zero encountered");
  double gapCost = computeGapCost(ali, gOpen, gExt);
  double md = computeMD(ali, sub);
  double normMd = (md - gapCost) / maxMd;
//   cout << "norm md gap max: " << normMd << " " << md << " " 
//        << gapCost << " " << maxMd << endl;
  ERROR_IF(!isReasonable(normMd), "Internal error in line 400");
  // cout << "Result of NormMD function: " << normMd << endl;
  return normMd;
}



double
computeMD(const Vec<string>& aliOrig,
	  const Vec<Vec<int> >& sub,
	  const string& alphabet)
{
  if (aliOrig.size() < 2) {
    return 0.0;
  }
  Vec<Vec<int> > ali = translateAliStringToInt(aliOrig, alphabet);
  return computeMD(ali, sub);
}


/** Checks that a sequence consists only of alphabet letters and gap characters.
    @param s sequence to check
    @param alphabet allowed residue letters
    @param gapChar character that is always accepted
    @return false as soon as a character is neither gapChar nor contained
            in alphabet, true otherwise (also for an empty string)
*/
bool
checkAlignmentString(const string& s, const string& alphabet,
                     char gapChar)
{
  for (string::size_type pos = 0; pos != s.size(); ++pos) {
    const char letter = s[pos];
    if (letter == gapChar) {
      continue;
    }
    if (alphabet.find(letter) == string::npos) {
      return false;
    }
  }
  return true;
}


/** Replaces every character of s that is neither gapChar nor contained in
    alphabet with repairChar.
    @param s sequence, repaired in place
    @param alphabet allowed residue letters
    @param gapChar character that is always left untouched
    @param repairChar replacement for illegal characters
    @return always true, whether or not a character was replaced
    NOTE(review): callers test the result for false; confirm that the
    unconditional true is intended.
*/
bool
checkAndRepairAlignmentString(string& s, const string& alphabet,
                     char gapChar, char repairChar)
{
  for (string::size_type pos = 0; pos != s.size(); ++pos) {
    if (s[pos] == gapChar) {
      continue;
    }
    if (alphabet.find(s[pos]) == string::npos) {
      s[pos] = repairChar;
    }
  }
  return true;
}

/** Validates an alignment: it must not be empty, all sequences must have
    the length of the first one, and every sequence must pass
    checkAlignmentString. On an illegal character the sequence is written
    to cerr and ERROR is raised before returning.
    @param ali alignment to check
    @param alphabet allowed residue letters
    @param gapChar accepted gap character
    @return empty string for a correct alignment, otherwise the message of
            the first problem found
*/
string
checkAlignment(SequenceAlignment& ali, const string& alphabet, char gapChar)
{
  if (ali.size() == 0) {
    return string("Zero size of alignment");
  }
  const unsigned int expectedLen = ali.getSequence(0).size();
  for (unsigned int idx = 0; idx < ali.size(); ++idx) {
    if (ali.getSequence(idx).size() != expectedLen) {
      return string("Sequences have different lengths!");
    }
    if (checkAlignmentString(ali.getSequence(idx), alphabet, gapChar)) {
      continue; // this sequence is fine
    }
    cerr << "Oops: " << ali.getSequence(idx) << endl;
    const string message = "Illegal character in sequence " + ali.getSequence(idx);
    ERROR("Illegal character in sequence found!");
    return message;
  }
  return string();
}

/** Validates an alignment and repairs illegal characters: every sequence
    is run through checkAndRepairAlignmentString and written back with
    setSequence under its existing name. A length mismatch is reported on
    cout and stops processing.
    @param ali alignment, modified in place
    @param alphabet allowed residue letters
    @param gapChar accepted gap character
    @param repairChar replacement for illegal characters
    @return empty string if no problem was reported, otherwise an error
            message (empty alignment, differing lengths, or the last
            sequence for which the repair helper returned false)
*/
string
checkAndRepairAlignment(SequenceAlignment& ali, const string& alphabet, char gapChar, char repairChar)
{
  if (ali.size() == 0) {
    return string("Zero size of alignment");
  }
  string errorMsg;
  const unsigned int expectedLen = ali.getSequence(0).size();
  for (unsigned int idx = 0; idx < ali.size(); ++idx) {
    string seq = ali.getSequence(idx);
    if (seq.size() != expectedLen) {
      cout << expectedLen << " " << seq.size() << endl
           << seq << endl;
      return string("Sequences have different lengths!");
    }
    const bool accepted = checkAndRepairAlignmentString(seq, alphabet, gapChar, repairChar);
    if (!accepted) {
      errorMsg = "Illegal character in sequence " + ali.getSequence(idx);
    }
    // store the (possibly repaired) sequence under its old name
    ali.setSequence(seq, ali.getName(idx), idx);
  }
  return errorMsg;
}

/** Removes sequences with duplicate names and normalizes gap characters.
    Whenever a later sequence carries the same name as sequence i,
    sequence i is removed and a message is printed to cout; the scan is
    repeated until a full pass removes nothing. Finally replace('.', '-')
    is applied to the alignment.
    @param ali alignment, modified in place
*/
void
cleanAlignment(SequenceAlignment& ali)
{
  bool removedAny;
  do {
    removedAny = false;
    for (unsigned int i = 0; i < ali.size(); ++i) {
      const string name = ali.getName(i);
      for (unsigned int j = i+1; j < ali.size(); ++j) {
        if (name.compare(ali.getName(j)) != 0) {
          continue;
        }
        cout << "Template name duplicate " << name << " " 
             << i + 1 << " " << j + 1 << endl;
        ali.removeSequence(i);
        removedAny = true;
        break; // indices have shifted; go on with the next i
      }
    }
  } while (removedAny);
  // replace '.' character to '-' for gap characters
  ali.replace('.', '-');
}

