#include "ObjFun.h"
#include "Error.h"
// NOTE(review): the header name of the next #include is missing in this copy
// of the file. The code below uses assert, std::stringstream, std::vector,
// std::string and std::ifstream -- confirm which header(s) belong here.
#include
namespace TNet {

/// Creates a new objective function of the requested type.
/// @param type  MEAN_SQUARE_ERROR or CROSS_ENTROPY.
/// @return Object allocated with `new`; nothing here frees it, so the caller
///         presumably takes ownership. For any other type Error() is called;
///         if Error() returns, NULL is returned.
ObjectiveFunction* ObjectiveFunction::Factory(ObjFunType type) {
  ObjectiveFunction* ret = NULL;
  switch(type) {
    case MEAN_SQUARE_ERROR: ret = new MeanSquareError; break;
    case CROSS_ENTROPY: ret = new CrossEntropy; break;
    default: Error("Unknown ObjectiveFunction type");
  }
  return ret;
}

/*
 * MeanSquareError
 */

/// Computes the error signal from net_out and target and accumulates the
/// squared-error loss.
/// @param net_out  network output; must have the same shape as target (asserted).
/// @param target   desired output.
/// @param err      output matrix; re-initialized to net_out's shape when the
///                 shapes differ, then overwritten.
/// Side effects: adds half the sum of squared elements of *err to error_ and
/// the number of rows of net_out to frames_.
void MeanSquareError::Evaluate(const Matrix& net_out, const Matrix& target, Matrix* err) {
  //check dimensions
  assert(net_out.Rows() == target.Rows());
  assert(net_out.Cols() == target.Cols());
  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
    err->Init(net_out.Rows(),net_out.Cols());
  }
  //compute global gradient
  //(presumably err = net_out - target, assuming AddScaled(a, M) adds a*M)
  err->Copy(net_out);
  err->AddScaled(-1,target);
  //compute loss function: half the sum of squared elements of *err
  double sum = 0;
  // NOTE(review): the loop conditions below are corrupted in this copy
  // (`rRows()` / `cCols()`); they presumably read `r<err->Rows()` and
  // `c<err->Cols()` -- restore from the upstream source.
  for(size_t r=0; rRows(); r++) {
    for(size_t c=0; cCols(); c++) {
      BaseFloat val = (*err)(r,c);
      sum += val*val;
    }
  }
  error_ += sum/2.0;
  frames_ += net_out.Rows();
}

/// Returns a one-line summary: accumulated loss, frame count and loss per
/// frame, terminated by a newline.
/// NOTE(review): error_/frames_ is evaluated even when frames_ == 0 (no frames
/// accumulated yet) -- confirm that output is acceptable.
std::string MeanSquareError::Report() {
  std::stringstream ss;
  ss << "Mse:" << error_ << " frames:" << frames_ << " err/frm:" << error_/frames_ << "\n";
  return ss.str();
}

/*
 * CrossEntropy
 */

///Find maximum in float array
/// Intended to return the index of the largest of the N values at ptr, or -1
/// if none exceeds the initial threshold -1e20f (e.g. when N == 0).
/// NOTE(review): the loop header / if-condition is corrupted in this copy:
/// text between `i` and `mval)` is missing (presumably `<N; i++) { if(ptr[i] >`),
/// which also leaves the braces below unbalanced. Restore from upstream.
inline int FindMaxId(const BaseFloat* ptr, size_t N) {
  BaseFloat mval = -1e20f;
  int mid = -1;
  for(size_t i=0; i mval) {
      mid = i;
      mval = ptr[i];
    }
  }
  return mid;
}

/// Computes the cross-entropy error signal and classification statistics
/// (only the beginning of this function survives in this copy).
/// @param net_out  network output; must have the same shape as target (asserted).
/// @param target   desired output.
/// @param err      output matrix; re-initialized to net_out's shape when the
///                 shapes differ, then overwritten.
void CrossEntropy::Evaluate(const Matrix& net_out, const Matrix& target, Matrix* err) {
  //check dimensions
  assert(net_out.Rows() == target.Rows());
  assert(net_out.Cols() == target.Cols());
  if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
    err->Init(net_out.Rows(),net_out.Cols());
  }
  //allocate confusion buffers: only when confusion statistics are enabled,
  //sized target.Cols() x target.Cols() (plus two vectors of target.Cols()),
  //and only re-initialized when the current size does not match
  if(confusion_mode_ != NO_CONF) {
    if(confusion_.Rows() != target.Cols() || confusion_.Cols() != target.Cols()) {
      confusion_.Init(target.Cols(),target.Cols());
      confusion_count_.Init(target.Cols());
      diag_confusion_.Init(target.Cols());
    }
  }
  //compute global gradient (assuming on softmax input)
  //(presumably err = net_out - target, assuming AddScaled(a, M) adds a*M)
  err->Copy(net_out);
  err->AddScaled(-1,target);
  //collect max values
  // NOTE(review): the template arguments of these two std::vector declarations
  // are missing in this copy (presumably an index type such as <int>).
  std::vector max_target_id(target.Rows());
  std::vector max_netout_id(target.Rows());
  //check correct classification
  int corr = 0;
  // NOTE(review): a large span is missing from this copy right after `r`:
  // the rest of this loop, the remainder of Evaluate, and the start of the
  // following (report) function, which presumably declares the `ss` stream
  // and the `tag` container used below (its element type is lost too).
  // This cannot be reconstructed here -- restore from the upstream source.
  for(size_t r=0; r tag;
  // Read label names from the file output_label_map_, one whitespace-separated
  // token per entry, into tag.
  // NOTE(review): eof() is tested before extraction, so if the file ends with
  // whitespace/newline the final failed `ifs >> str` leaves str unchanged and
  // the last label is pushed twice. If the file cannot be opened and asserts
  // are disabled (NDEBUG), eofbit is never set and this loop does not
  // terminate. `while(ifs >> str)` avoids both problems.
  {
    std::ifstream ifs(output_label_map_);
    assert(ifs.good());
    std::string str;
    while(!ifs.eof()) {
      ifs >> str;
      tag.push_back(str);
    }
  }
  assert(confusion_count_.Dim() <= tag.size());
  //print confusion matrix
  if(confusion_mode_ == MAX_CONF || confusion_mode_ == SOFT_CONF) {
    ss << "Row:label Col:hyp\n" << confusion_ << "\n";
  }
  //***print per-target accuracies
  // NOTE(review): another span is missing after `i`: the per-target accuracy
  // loop, the end of the report function, and the start of the function that
  // receives `inst` (it appears to bind `xent` to `inst` viewed as a
  // CrossEntropy, via a cast whose type argument was lost).
  for(int i=0; i(inst);
  // Accumulate the statistics of the other instance (xent) into this one.
  frames_ += xent.frames_;
  error_ += xent.error_;
  corr_ += xent.corr_;
  //sum the confusion statistics; the local buffers are re-initialized to the
  //other instance's sizes when their row counts differ
  if(confusion_mode_ != NO_CONF) {
    if(confusion_.Rows() != xent.confusion_.Rows()) {
      confusion_.Init(xent.confusion_.Rows(),xent.confusion_.Cols());
      confusion_count_.Init(xent.confusion_count_.Dim());
      diag_confusion_.Init(xent.diag_confusion_.Dim());
    }
    confusion_.Add(xent.confusion_);
    confusion_count_.Add(xent.confusion_count_);
    diag_confusion_.Add(xent.diag_confusion_);
  }
}
} // namespace TNet