Diffstat (limited to 'src/TNetLib/ObjFun.cc')
-rw-r--r-- | src/TNetLib/ObjFun.cc | 231
1 file changed, 231 insertions, 0 deletions
diff --git a/src/TNetLib/ObjFun.cc b/src/TNetLib/ObjFun.cc
new file mode 100644
index 0000000..c899fb1
--- /dev/null
+++ b/src/TNetLib/ObjFun.cc
@@ -0,0 +1,231 @@
+
+#include "ObjFun.h"
+#include "Error.h"
+
+#include <limits>
+
+namespace TNet {
+
+
+ObjectiveFunction* ObjectiveFunction::Factory(ObjFunType type) {
+  ObjectiveFunction* ret = NULL;
+  switch(type) {
+    case MEAN_SQUARE_ERROR: ret = new MeanSquareError; break;
+    case CROSS_ENTROPY: ret = new CrossEntropy; break;
+    default: Error("Unknown ObjectiveFunction type");
+  }
+  return ret;
+}
+
+
+/*
+ * MeanSquareError
+ */
+void MeanSquareError::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err) {
+
+  //check dimensions
+  assert(net_out.Rows() == target.Rows());
+  assert(net_out.Cols() == target.Cols());
+  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
+    err->Init(net_out.Rows(),net_out.Cols());
+  }
+
+  //compute global gradient: err = net_out - target
+  err->Copy(net_out);
+  err->AddScaled(-1,target);
+
+  //compute loss function: half the sum of squared differences
+  double sum = 0;
+  for(size_t r=0; r<err->Rows(); r++) {
+    for(size_t c=0; c<err->Cols(); c++) {
+      BaseFloat val = (*err)(r,c);
+      sum += val*val;
+    }
+  }
+  error_ += sum/2.0;
+  frames_ += net_out.Rows();
+}
+
+
+std::string MeanSquareError::Report() {
+  std::stringstream ss;
+  ss << "Mse:" << error_ << " frames:" << frames_
+     << " err/frm:" << error_/frames_
+     << "\n";
+  return ss.str();
+}
+
+
+/*
+ * CrossEntropy
+ */
+
+///Find the index of the maximum value in a float array
+inline int FindMaxId(const BaseFloat* ptr, size_t N) {
+  BaseFloat mval = -1e20f;
+  int mid = -1;
+  for(size_t i=0; i<N; i++) {
+    if(ptr[i] > mval) {
+      mid = i; mval = ptr[i];
+    }
+  }
+  return mid;
+}
+
+
+void
+CrossEntropy::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err)
+{
+  //check dimensions
+  assert(net_out.Rows() == target.Rows());
+  assert(net_out.Cols() == target.Cols());
+  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
+    err->Init(net_out.Rows(),net_out.Cols());
+  }
+
+  //allocate confusion buffers
+  if(confusion_mode_ != NO_CONF) {
+    if(confusion_.Rows() != target.Cols() || confusion_.Cols() != target.Cols()) {
+      confusion_.Init(target.Cols(),target.Cols());
+      confusion_count_.Init(target.Cols());
+      diag_confusion_.Init(target.Cols());
+    }
+  }
+
+  //compute global gradient (assumes softmax outputs, so the gradient w.r.t. the softmax input is net_out - target)
+  err->Copy(net_out);
+  err->AddScaled(-1,target);
+
+  //collect argmax positions of targets and network outputs
+  std::vector<size_t> max_target_id(target.Rows());
+  std::vector<size_t> max_netout_id(target.Rows());
+  //count correctly classified frames
+  int corr = 0;
+  for(size_t r=0; r<net_out.Rows(); r++) {
+    int id_netout = FindMaxId(net_out[r].pData(),net_out.Cols());
+    int id_target = FindMaxId(target[r].pData(),target.Cols());
+    if(id_netout == id_target) corr++;
+    max_target_id[r] = id_target;//store the argmax index
+    max_netout_id[r] = id_netout;
+  }
+
+  //compute loss function
+  double sumerr = 0;
+  for(size_t r=0; r<net_out.Rows(); r++) {
+    if(target(r,max_target_id[r]) == 1.0) {
+      //1-of-K target: only the max element contributes, the rest are zero
+      BaseFloat val = log(net_out(r,max_target_id[r]));
+      if(val < -1e10f) val = -1e10f;
+      sumerr += val;
+    } else {
+      //soft target: process the whole posterior vector
+      for(size_t c=0; c<net_out.Cols(); c++) {
+        if(target(r,c) != 0.0) {
+          BaseFloat val = target(r,c)*log(net_out(r,c));
+          if(val < -1e10f) val = -1e10f;
+          sumerr += val;
+        }
+      }
+    }
+  }
+
+  //accumulate confusion statistics
+  if(confusion_mode_ != NO_CONF) {
+    for(size_t r=0; r<net_out.Rows(); r++) {
+      int id_target = max_target_id[r];
+      int id_netout = max_netout_id[r];
+      switch(confusion_mode_) {
+        case MAX_CONF:
+          confusion_(id_target,id_netout) += 1;
+          break;
+        case SOFT_CONF:
+          confusion_[id_target].Add(net_out[r]);
+          break;
+        case DIAG_MAX_CONF:
+          diag_confusion_[id_target] += ((id_target==id_netout)?1:0);
+          break;
+        case DIAG_SOFT_CONF:
+          diag_confusion_[id_target] += net_out[r][id_target];
+          break;
+        default:
+          KALDI_ERR << "Unknown confusion type " << confusion_mode_;
+      }
+      confusion_count_[id_target] += 1;
+    }
+  }
+
+  error_ -= sumerr;
+  frames_ += net_out.Rows();
+  corr_ += corr;
+}
+
+
+std::string CrossEntropy::Report() {
+  std::stringstream ss;
+  ss << "Xent:" << error_ << " frames:" << frames_
+     << " err/frm:" << error_/frames_
+     << " correct[" << 100.0*corr_/frames_ << "%]"
+     << "\n";
+
+  if(confusion_mode_ != NO_CONF) {
+    //read class tags
+    std::vector<std::string> tag;
+    {
+      std::ifstream ifs(output_label_map_);
+      assert(ifs.good());
+      std::string str;
+      //stream-state check avoids pushing the last token twice at EOF
+      while(ifs >> str) {
+        tag.push_back(str);
+      }
+    }
+    assert(confusion_count_.Dim() <= tag.size());
+
+    //print confusion matrix
+    if(confusion_mode_ == MAX_CONF || confusion_mode_ == SOFT_CONF) {
+      ss << "Row:label Col:hyp\n" << confusion_ << "\n";
+    }
+
+    //***print per-target accuracies
+    for(int i=0; i<confusion_count_.Dim(); i++) {
+      //get the numerator
+      BaseFloat numerator = 0.0;
+      switch (confusion_mode_) {
+        case MAX_CONF: case SOFT_CONF:
+          numerator = confusion_[i][i];
+          break;
+        case DIAG_MAX_CONF: case DIAG_SOFT_CONF:
+          numerator = diag_confusion_[i];
+          break;
+        default:
+          KALDI_ERR << "Unsupported confusion mode:" << confusion_mode_;
+      }
+      //add line to report
+      ss << std::setw(30) << tag[i] << " "
+         << std::setw(10) << 100.0*numerator/confusion_count_[i] << "%"
+         << " [" << numerator << "/" << confusion_count_[i] << "]\n";
+    } //***print per-target accuracies
+  }// != NO_CONF
+
+  return ss.str();
+}
+
+
+void CrossEntropy::MergeStats(const ObjectiveFunction& inst) {
+  const CrossEntropy& xent = dynamic_cast<const CrossEntropy&>(inst);
+  frames_ += xent.frames_; error_ += xent.error_; corr_ += xent.corr_;
+  //sum the confusion statistics
+  if(confusion_mode_ != NO_CONF) {
+    if(confusion_.Rows() != xent.confusion_.Rows()) {
+      confusion_.Init(xent.confusion_.Rows(),xent.confusion_.Cols());
+      confusion_count_.Init(xent.confusion_count_.Dim());
+      diag_confusion_.Init(xent.diag_confusion_.Dim());
+    }
+    confusion_.Add(xent.confusion_);
+    confusion_count_.Add(xent.confusion_count_);
+    diag_confusion_.Add(xent.diag_confusion_);
+  }
+}
+
+
+} // namespace TNet
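
For orientation only, a minimal, hypothetical driver sketching how this interface could be exercised. It is not part of the commit; it assumes ObjFun.h (not shown here) makes ObjFunType and CROSS_ENTROPY visible as written, that CrossEntropy defaults to the NO_CONF confusion mode, and it reuses only the Matrix<BaseFloat> calls already seen above (Init, operator()(r,c)).

// usage_sketch.cc -- hypothetical example, assumptions as stated above
#include <iostream>
#include "ObjFun.h"

int main() {
  using namespace TNet;

  // one frame, three classes: softmax posterior vs. 1-of-K target
  Matrix<BaseFloat> net_out, target, err;
  net_out.Init(1,3);
  target.Init(1,3);
  net_out(0,0) = 0.7; net_out(0,1) = 0.2; net_out(0,2) = 0.1;
  target(0,0)  = 1.0; target(0,1)  = 0.0; target(0,2)  = 0.0;

  // Factory and CROSS_ENTROPY come from the file above; scoping is assumed
  ObjectiveFunction* obj = ObjectiveFunction::Factory(CROSS_ENTROPY);

  // accumulates error_, frames_, corr_ and fills err = net_out - target
  obj->Evaluate(net_out, target, &err);

  // prints something like "Xent:0.356675 frames:1 err/frm:0.356675 correct[100%]"
  std::cout << obj->Report();

  delete obj;
  return 0;
}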