From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001
From: Joe Zhao
Date: Mon, 14 Apr 2014 08:14:45 +0800
Subject: First commit

---
 src/TNetLib/.svn/text-base/ObjFun.cc.svn-base | 231 ++++++++++++++++++++++++++
 1 file changed, 231 insertions(+)
 create mode 100644 src/TNetLib/.svn/text-base/ObjFun.cc.svn-base

diff --git a/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base b/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base
new file mode 100644
index 0000000..c899fb1
--- /dev/null
+++ b/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base
@@ -0,0 +1,231 @@
+
+#include "ObjFun.h"
+#include "Error.h"
+
+#include <sstream>  //[header name lost in extraction; <sstream> is needed for std::stringstream below]
+
+namespace TNet {
+
+
+ObjectiveFunction* ObjectiveFunction::Factory(ObjFunType type) {
+  ObjectiveFunction* ret = NULL;
+  switch(type) {
+    case MEAN_SQUARE_ERROR: ret = new MeanSquareError; break;
+    case CROSS_ENTROPY: ret = new CrossEntropy; break;
+    default: Error("Unknown ObjectiveFunction type");
+  }
+  return ret;
+}
+
+
+/*
+ * MeanSquareError
+ */
+void MeanSquareError::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err) {
+
+  //check dimensions
+  assert(net_out.Rows() == target.Rows());
+  assert(net_out.Cols() == target.Cols());
+  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
+    err->Init(net_out.Rows(),net_out.Cols());
+  }
+
+  //compute global gradient
+  err->Copy(net_out);
+  err->AddScaled(-1,target);
+
+  //compute loss function
+  double sum = 0;
+  for(size_t r=0; r<err->Rows(); r++) {
+    for(size_t c=0; c<err->Cols(); c++) {
+      BaseFloat val = (*err)(r,c);
+      sum += val*val;
+    }
+  }
+  error_ += sum/2.0;
+  frames_ += net_out.Rows();
+}
+
+
+std::string MeanSquareError::Report() {
+  std::stringstream ss;
+  ss << "Mse:" << error_ << " frames:" << frames_
+     << " err/frm:" << error_/frames_
+     << "\n";
+  return ss.str();
+}
+
+
+/*
+ * CrossEntropy
+ */
+
+///Find maximum in float array
+inline int FindMaxId(const BaseFloat* ptr, size_t N) {
+  BaseFloat mval = -1e20f; int mid = -1;
+  for(size_t i=0; i<N; i++) {
+    if(ptr[i] > mval) {
+      mid = i; mval = ptr[i];
+    }
+  }
+  return mid;
+}
+
+
+void
+CrossEntropy::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err)
+{
+  //check dimensions
+  assert(net_out.Rows() == target.Rows());
+  assert(net_out.Cols() == target.Cols());
+  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
+    err->Init(net_out.Rows(),net_out.Cols());
+  }
+
+  //allocate confusion buffers
+  if(confusion_mode_ != NO_CONF) {
+    if(confusion_.Rows() != target.Cols() || confusion_.Cols() != target.Cols()) {
+      confusion_.Init(target.Cols(),target.Cols());
+      confusion_count_.Init(target.Cols());
+      diag_confusion_.Init(target.Cols());
+    }
+  }
+
+  //compute global gradient (gradient w.r.t. the softmax input, assuming a softmax output layer)
+  err->Copy(net_out);
+  err->AddScaled(-1,target);
+
+  //collect max values
+  std::vector<size_t> max_target_id(target.Rows());
+  std::vector<size_t> max_netout_id(target.Rows());
+  //check correct classification
+  int corr = 0;
+  //NOTE: the remainder of Evaluate() and the head of Report() were lost when this page was
+  //      extracted; the lines below are reconstructed from the statistics used in
+  //      Report()/MergeStats() (the original presumably used FindMaxId() on the rows).
+  for(size_t r=0; r<net_out.Rows(); r++) {
+    BaseFloat tgt_max = -1e20f, out_max = -1e20f;
+    max_target_id[r] = 0; max_netout_id[r] = 0;
+    for(size_t c=0; c<net_out.Cols(); c++) {
+      if(target(r,c)  > tgt_max) { tgt_max = target(r,c);  max_target_id[r] = c; }
+      if(net_out(r,c) > out_max) { out_max = net_out(r,c); max_netout_id[r] = c; }
+    }
+    if(max_target_id[r] == max_netout_id[r]) corr++;
+  }
+
+  //accumulate confusion statistics
+  if(confusion_mode_ != NO_CONF) {
+    for(size_t r=0; r<net_out.Rows(); r++) {
+      size_t tgt = max_target_id[r], hyp = max_netout_id[r];
+      if(confusion_mode_ == MAX_CONF) {
+        confusion_(tgt,hyp) += 1.0; //hard count of the winning hypothesis
+      } else { //SOFT_CONF: accumulate the whole posterior row
+        for(size_t c=0; c<net_out.Cols(); c++) {
+          confusion_(tgt,c) += net_out(r,c);
+        }
+      }
+      confusion_count_(tgt) += 1.0;
+      if(tgt == hyp) diag_confusion_(tgt) += 1.0;
+    }
+  }
+
+  //accumulate the cross-entropy loss, -log p(correct class), and the frame counter
+  for(size_t r=0; r<net_out.Rows(); r++) {
+    error_ -= log(net_out(r,max_target_id[r]));
+  }
+  frames_ += net_out.Rows();
+  corr_ += corr;
+}
+
+
+std::string CrossEntropy::Report() {
+  std::stringstream ss;
+  ss << "Xent:" << error_ << " frames:" << frames_
+     << " err/frm:" << error_/frames_
+     << " correct[" << 100.0*corr_/frames_ << "%]"
+     << "\n";
+
+  if(confusion_mode_ != NO_CONF) {
+    //read the output label map
+    std::vector<std::string> tag;
+    {
+      std::ifstream ifs(output_label_map_);
+      assert(ifs.good());
+      std::string str;
+      while(ifs >> str) {
+        tag.push_back(str);
+      }
+    }
+    assert(confusion_count_.Dim() <= tag.size());
+
+    //print confusion matrix
+    if(confusion_mode_ == MAX_CONF || confusion_mode_ == SOFT_CONF) {
+      ss << "Row:label Col:hyp\n" << confusion_ << "\n";
+    }
+
+    //***print per-target accuracies
+    //NOTE: the body of this loop and the head of MergeStats() were lost in extraction; reconstructed from context.
+    for(int i=0; i<confusion_count_.Dim(); i++) {
+      ss << "label:" << tag[i]
+         << " frames:" << confusion_count_(i)
+         << " accuracy:" << 100.0*diag_confusion_(i)/confusion_count_(i) << "%"
+         << "\n";
+    }
+  }
+  return ss.str();
+}
+
+
+void CrossEntropy::MergeStats(const ObjectiveFunction& inst) {
+  const CrossEntropy& xent = dynamic_cast<const CrossEntropy&>(inst);
+  frames_ += xent.frames_; error_ += xent.error_; corr_ += xent.corr_;
+  //sum the confusion statistics
+  if(confusion_mode_ != NO_CONF) {
+    if(confusion_.Rows() != xent.confusion_.Rows()) {
+      confusion_.Init(xent.confusion_.Rows(),xent.confusion_.Cols());
+      confusion_count_.Init(xent.confusion_count_.Dim());
+      diag_confusion_.Init(xent.diag_confusion_.Dim());
+    }
+    confusion_.Add(xent.confusion_);
+    confusion_count_.Add(xent.confusion_count_);
+    diag_confusion_.Add(xent.diag_confusion_);
+  }
+}
+
+
+} // namespace TNet
-- 
cgit v1.2.3-70-g09d2
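
For orientation, the sketch below shows how the ObjectiveFunction interface added by this patch would typically be driven from a training loop. It is illustrative only and not part of the commit: it assumes that ObjFun.h makes BaseFloat, Matrix<BaseFloat> (with Init() and element access via operator()), and the ObjFunType enumerators visible at TNet namespace scope, and the 2x2 minibatch data is made up.

// illustrative driver, not part of the patch; see the assumptions above
#include "ObjFun.h"
#include <iostream>

using namespace TNet;

int main() {
  //pick the objective via the factory (MEAN_SQUARE_ERROR or CROSS_ENTROPY)
  ObjectiveFunction* obj = ObjectiveFunction::Factory(MEAN_SQUARE_ERROR);

  //a made-up 2-frame, 2-class minibatch of network outputs and 1-of-K targets
  Matrix<BaseFloat> net_out, target, err;
  net_out.Init(2,2); target.Init(2,2);
  net_out(0,0)=0.8f; net_out(0,1)=0.2f; net_out(1,0)=0.4f; net_out(1,1)=0.6f;
  target(0,0)=1.0f;  target(0,1)=0.0f;  target(1,0)=0.0f;  target(1,1)=1.0f;

  //Evaluate() fills err = net_out - target (the gradient handed to backprop)
  //and accumulates the summed squared error / 2 plus the frame counter
  obj->Evaluate(net_out, target, &err);

  //prints a line like "Mse:0.2 frames:2 err/frm:0.1"
  std::cout << obj->Report();

  delete obj;
  return 0;
}

In multi-threaded training each thread would keep its own objective-function instance, and the per-thread statistics would afterwards be folded together with MergeStats(), which is what the confusion-matrix merging at the end of the patch supports.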