// --*- C++ -*------x---------------------------------------------------------
// $Id: KnnNode.h,v 1.1.1.1 2006/07/03 14:43:21 bindewae Exp $
//
// Class:           KnnNode
// 
// Base class:      -
//
// Derived classes: - 
//
// Author:          Eckart Bindewald
//
// Description:     prediction using K-nearest neighbor method
// 
// Reviewed by:     -
// -----------------x-------------------x-------------------x-----------------

#ifndef __KNN_NODE_H__
#define __KNN_NODE_H__

// Includes

#include <iostream>

#include <debug.h>
#include <Vec.h>
#include <ClassifierBase.h>

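/** Classifier that predicts class probabilities with the k-nearest-neighbor method.

    A prediction for a query vector is a vote among the classes of the
    k nearest stored training rows (see predictClassProb). Training rows
    are stored together with their class labels and a scaling vector.

    @author Eckart Bindewald
    @see    Modern Applied Statistics with S / Venables & Ripley
    @review - */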
class KnnNode : public ClassifierBase {
public:

  enum { NUM_STAT_BINS = 10 };
  
  KnnNode();

  KnnNode(const KnnNode& orig);

  virtual ~KnnNode();

  /* OPERATORS */

  /** Assigment operator. */
  KnnNode& operator = (const KnnNode& orig);

  friend ostream& operator << (ostream& os, const KnnNode& rval);

  friend istream& operator >> (istream& is, KnnNode& rval);
 

  /* PREDICATES */

  /** Is current state valid? */
  virtual bool isValid() const { return (kk > 0) && (data.size() > 0); }
  
  /** How big is object? */
  virtual unsigned int size() const { return data.size(); }

  /** return dimension of each data vector */
  virtual unsigned int getDim() const {
    if (data.size() == 0) {
      return 0;
    }
    return data[0].size();
  }

  virtual unsigned int getNumClasses() const { return numClasses; }

  /*
  virtual unsigned int predictClass(const Vec<double>& v) const;
  */

  /** central prediction method. Returns probability of each class according to 
      vote of classes of k nearest neighbors */
  virtual Vec<double> predictClassProb(const Vec<double>& v) const;

  /** central prediction method. Returns probability of each class according to 
      vote of classes of k nearest neighbors */
  virtual Vec<double> predictClassProb(const Vec<double>& v, 
				       unsigned int knownClass) const;

  virtual Vec<double> getLastPrediction() const { return lastPrediction; }

  /** returns parameter "k" of k-nearest neighbor algorithm */
  virtual unsigned int getK() const { return kk; }

  virtual const Vec<Vec<double> >& getData() const { return data; }

  /** returns data rows which belong to class dataClass */
  virtual Vec<Vec<double> > getData(unsigned int dataClass) const;

  /** returns indices of data rows which belong to class dataClass */
  virtual Vec<unsigned int> getDataIndices(unsigned int dataClass) const;

  /** returns data row n */
  virtual const Vec<double>& getDataRow(unsigned int n) const { return data[n]; }

  /** return class of data row n */
  virtual unsigned int getDataRowClass(unsigned int n) const { return dataClasses[n]; }

  /** returns scaling */
  virtual const Vec<double>& getScaling() const { return scaling; }

  /** returns prediction accuracy using leave one out estimation (numTrial times) */
  virtual double estimateAccuracy(unsigned int numTrials) const;

  /** gets cutoff for simpleRepresentativeLinkage */
  virtual double getClusterCutoff() const { return clusterCutoff; }
  
  Vec<Vec<double> > getUsageHistogram() const;

  virtual bool isNoSelfMode() const { return noSelfMode; }

  void writeData(ostream& os) const;

  virtual double getGaussDev() const { return gaussDev; }
 
  virtual unsigned int getNumClusters() const { return clusters.size(); }

  virtual int getVerboseLevel() const { return verboseLevel; }

  /* MODIFIERS */

  /** optimize scaling of node using simple Monte Carlo steps */
  virtual void optimizeScaling(int numSteps, int verboseLevel,
		       double stepWidth, unsigned int numTrials);

  /** read input data */
  virtual void readData(istream& is, 
			unsigned int startCol, 
			unsigned int endCol, 
			unsigned int classCol);

  /** read input data */
  virtual void readData(istream& is, 
			const Vec<unsigned int>& mask);

  virtual void recluster();

  /** sets cutoff for simpleRepresentativeLinkage */
  virtual void setClusterCutoff(double x) { clusterCutoff = x; }

  /** sets cutoff for simpleRepresentativeLinkage */
  virtual void setClusterCutoff2(double x) { clusterCutoff2 = x; }

  /** dangerous! Use other setData methods when possible */
//   virtual void setData(const Vec<Vec<double> >& mtx) {
//     data = mtx; 
//   }

  virtual void setData(const Vec<Vec<double> >& mtx, 
		       const Vec<unsigned int>& dClasses,
		       unsigned int _nClass,
		       const Vec<double>& _scale) { 
    PRECOND(mtx.size() == dClasses.size());
    data = mtx; dataClasses = dClasses; numClasses = _nClass; scaling = _scale;
    recluster();
  }
  
  /** sets parameter "k" of k-nearest neighbor algorithm */
  virtual void setK(unsigned int k) { kk = k; }

  virtual void setNumClasses(unsigned int n) { numClasses = n; }

  virtual void setScaling(const Vec<double>& v) { scaling = v; }

  /** start statistics mode */
  virtual void startStatistics();

  virtual void setGaussDev(double g) { gaussDev = g; }
  
  virtual void setNoSelfMode(bool b) { noSelfMode = b; }

  virtual void setVerboseLevel(int n) { verboseLevel = n; }
  
  /** reduce learned data */
  virtual void thin(unsigned int thinK);

  virtual void updateStatistics(const Vec<double>& prediction, unsigned int knownClass) const;



protected:
  /* OPERATORS  */
  /* PREDICATES */
  /* MODIFIERS  */
  virtual void copy(const KnnNode& other);

private:
  /* OPERATORS  */
  /* PREDICATES */

  /** returns prediction accuracy using leave one out estimation (numTrial times) */
  virtual Vec<unsigned int> initEstimateAccuracy() const;

  /* MODIFIERS  */

private:
  
  /* PRIVATE ATTRIBUTES */

  mutable Vec<unsigned int> estimateSet;

  double clusterCutoff;

  double clusterCutoff2;

  double gaussDev;

  unsigned int kk;

  unsigned int numClasses;

  int verboseLevel;

  Vec<double> scaling;

  Vec<Vec<double> > data;

  mutable Vec<double> lastPrediction;

  Vec<unsigned int> dataClasses;

  Vec<Vec<Vec<double> > > clustData;

  Vec<Vec<unsigned int> > clusters;

  Vec<Vec<Vec<unsigned int> > > subClusters;

  // Vec<Vec<Vec<unsigned int> > > subClustersAbs; // absolute indices

  bool noSelfMode;

  bool statisticsRunning;

  mutable Vec<Vec<unsigned int long > > trueCount; // for each class, generate 10 bins, add one if true class for certain score

  mutable Vec<Vec<unsigned int long > > falseCount; // for each class, generate 10 bins, add one if wrong class for certain score

};

#endif /* __KNN_NODE_H__ */

