#include "ObjFun.h"
#include "Error.h"

#include <limits>

namespace TNet {


ObjectiveFunction* ObjectiveFunction::Factory(ObjFunType type) {
  ObjectiveFunction* ret = NULL;
  switch(type) {
    case MEAN_SQUARE_ERROR: ret = new MeanSquareError;    break;
    case CROSS_ENTROPY:     ret = new CrossEntropy;       break;
    default: Error("Unknown ObjectiveFunction type");
  }
  return ret;
}


/*
 * MeanSquareError
 */
void MeanSquareError::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err) {
  
  //check dimensions
  assert(net_out.Rows() == target.Rows());
  assert(net_out.Cols() == target.Cols());
  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
    err->Init(net_out.Rows(),net_out.Cols());
  }

  //compute global gradient
  err->Copy(net_out);
  err->AddScaled(-1,target);

  //compute loss function
  double sum = 0;
  for(size_t r=0; r<err->Rows(); r++) {
    for(size_t c=0; c<err->Cols(); c++) {
      BaseFloat val = (*err)(r,c);
      sum += val*val;
    }
  }
  error_ += sum/2.0;
  frames_ += net_out.Rows();
}


std::string MeanSquareError::Report() {
  std::stringstream ss;
  ss << "Mse:" << error_ << " frames:" << frames_
     << " err/frm:" << error_/frames_
     << "\n";
  return ss.str();
}


/*
 * CrossEntropy
 */

///Find maximum in float array
inline int FindMaxId(const BaseFloat* ptr, size_t N) {
  BaseFloat mval = -1e20f;
  int mid = -1;
  for(size_t i=0; i<N; i++) {
    if(ptr[i] > mval) {
      mid = i; mval = ptr[i];
    }
  }
  return mid;
}


void
CrossEntropy::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err)
{
  //check dimensions
  assert(net_out.Rows() == target.Rows());
  assert(net_out.Cols() == target.Cols());
  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
    err->Init(net_out.Rows(),net_out.Cols());
  }

  //allocate confunsion buffers
  if(confusion_mode_ != NO_CONF) {
    if(confusion_.Rows() != target.Cols() || confusion_.Cols() != target.Cols()) {
      confusion_.Init(target.Cols(),target.Cols());
      confusion_count_.Init(target.Cols());
      diag_confusion_.Init(target.Cols());
    }
  }

  //compute global gradient (assuming on softmax input)
  err->Copy(net_out);
  err->AddScaled(-1,target);

  //collect max values
  std::vector<size_t> max_target_id(target.Rows());
  std::vector<size_t> max_netout_id(target.Rows());
  //check correct classification
  int corr = 0;
  for(size_t r=0; r<net_out.Rows(); r++) {
    int id_netout = FindMaxId(net_out[r].pData(),net_out.Cols());
    int id_target = FindMaxId(target[r].pData(),target.Cols());
    if(id_netout == id_target) corr++;
    max_target_id[r] = id_target;//store the max value
    max_netout_id[r] = id_netout;
  }

  //compute loss function
  double sumerr = 0;
  for(size_t r=0; r<net_out.Rows(); r++) {
    if(target(r,max_target_id[r]) == 1.0) {
      //pick the max value..., rest is zero
      BaseFloat val = log(net_out(r,max_target_id[r]));
      if(val < -1e10f) val = -1e10f;
      sumerr += val;
    } else {
      //process whole posterior vect.
      for(size_t c=0; c<net_out.Cols(); c++) {
        if(target(r,c) != 0.0) {
          BaseFloat val = target(r,c)*log(net_out(r,c));
          if(val < -1e10f) val = -1e10f;
          sumerr += val;
        }
      }
    }
  }

  //accumulate confusuion network
  if(confusion_mode_ != NO_CONF) {
    for(size_t r=0; r<net_out.Rows(); r++) {
      int id_target = max_target_id[r];
      int id_netout = max_netout_id[r];
      switch(confusion_mode_) {
        case MAX_CONF:
          confusion_(id_target,id_netout) += 1;
          break;
        case SOFT_CONF:
          confusion_[id_target].Add(net_out[r]);
          break;
        case DIAG_MAX_CONF:
          diag_confusion_[id_target] += ((id_target==id_netout)?1:0);
          break;
        case DIAG_SOFT_CONF:
          diag_confusion_[id_target] += net_out[r][id_target];
          break;
        default:
          KALDI_ERR << "unknown confusion type" << confusion_mode_;
      }
      confusion_count_[id_target] += 1;
    }
  }

  error_ -= sumerr;
  frames_ += net_out.Rows();
  corr_ += corr;
}


std::string CrossEntropy::Report() {
  std::stringstream ss;
  ss << "Xent:" << error_ << " frames:" << frames_
     << " err/frm:" << error_/frames_
     << " correct[" << 100.0*corr_/frames_ << "%]"
     << "\n";

  if(confusion_mode_ != NO_CONF) {
    //read class tags
    std::vector<std::string> tag;
    { 
      std::ifstream ifs(output_label_map_);
      assert(ifs.good());
      std::string str;
      while(!ifs.eof()) {
        ifs >> str;
        tag.push_back(str);
      }
    }
    assert(confusion_count_.Dim() <= tag.size());

    //print confusion matrix
    if(confusion_mode_ == MAX_CONF || confusion_mode_ == SOFT_CONF) {
      ss << "Row:label Col:hyp\n" << confusion_ << "\n";
    }
    
    //***print per-target accuracies
    for(int i=0; i<confusion_count_.Dim(); i++) {
      //get the numerator
      BaseFloat numerator = 0.0;
      switch (confusion_mode_) {
        case MAX_CONF: case SOFT_CONF:
          numerator = confusion_[i][i];
          break;
        case DIAG_MAX_CONF: case DIAG_SOFT_CONF:
          numerator = diag_confusion_[i];
          break;
        default:
          KALDI_ERR << "Usupported confusion mode:" << confusion_mode_;
      }
      //add line to report
      ss << std::setw(30) << tag[i] << " " 
         << std::setw(10) << 100.0*numerator/confusion_count_[i] << "%" 
         << " [" << numerator << "/" << confusion_count_[i] << "]\n";
    } //***print per-target accuracies
  }// != NO_CONF

  return ss.str();
}


void CrossEntropy::MergeStats(const ObjectiveFunction& inst) { 
  const CrossEntropy& xent = dynamic_cast<const CrossEntropy&>(inst);
  frames_ += xent.frames_; error_ += xent.error_; corr_ += xent.corr_;
  //sum the confustion statistics
  if(confusion_mode_ != NO_CONF) {
    if(confusion_.Rows() != xent.confusion_.Rows()) {
      confusion_.Init(xent.confusion_.Rows(),xent.confusion_.Cols());
      confusion_count_.Init(xent.confusion_count_.Dim());
      diag_confusion_.Init(xent.diag_confusion_.Dim());
    }
    confusion_.Add(xent.confusion_);
    confusion_count_.Add(xent.confusion_count_);
    diag_confusion_.Add(xent.diag_confusion_);
  }
}
 

} // namespace TNet