summaryrefslogtreecommitdiff
path: root/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base
diff options
context:
space:
mode:
author Joe Zhao <ztuowen@gmail.com> 2014-04-14 08:14:45 +0800
committer Joe Zhao <ztuowen@gmail.com> 2014-04-14 08:14:45 +0800
commit cccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNetLib/.svn/text-base/ObjFun.cc.svn-base
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/TNetLib/.svn/text-base/ObjFun.cc.svn-base')
-rw-r--r-- src/TNetLib/.svn/text-base/ObjFun.cc.svn-base 231
1 files changed, 231 insertions, 0 deletions
diff --git a/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base b/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base
new file mode 100644
index 0000000..c899fb1
--- /dev/null
+++ b/src/TNetLib/.svn/text-base/ObjFun.cc.svn-base
@@ -0,0 +1,231 @@
+
+#include "ObjFun.h"
+#include "Error.h"
+
+#include <limits>
+
+namespace TNet {
+
+
+/// Create a new objective function instance of the requested type.
+///
+/// @param type selects the concrete class to instantiate
+/// @return a heap-allocated instance created with `new`, or NULL when
+///         the type is not recognised and Error() returns to the caller
+ObjectiveFunction* ObjectiveFunction::Factory(ObjFunType type) {
+  switch(type) {
+    case MEAN_SQUARE_ERROR:
+      return new MeanSquareError;
+    case CROSS_ENTROPY:
+      return new CrossEntropy;
+    default:
+      break;
+  }
+  //no known type matched: report it, fall back to NULL
+  Error("Unknown ObjectiveFunction type");
+  return NULL;
+}
+
+
+/*
+ * MeanSquareError
+ */
+void MeanSquareError::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err) {
+
+ //check dimensions
+ assert(net_out.Rows() == target.Rows());
+ assert(net_out.Cols() == target.Cols());
+ if(err->Rows() != net_out.Rows() || err->Cols() != net_out.Cols()) {
+ err->Init(net_out.Rows(),net_out.Cols());
+ }
+
+ //compute global gradient
+ err->Copy(net_out);
+ err->AddScaled(-1,target);
+
+ //compute loss function
+ double sum = 0;
+ for(size_t r=0; r<err->Rows(); r++) {
+ for(size_t c=0; c<err->Cols(); c++) {
+ BaseFloat val = (*err)(r,c);
+ sum += val*val;
+ }
+ }
+ error_ += sum/2.0;
+ frames_ += net_out.Rows();
+}
+
+
+/// Format the accumulated statistics as a single text line.
+///
+/// @return "Mse:<error> frames:<frames> err/frm:<error/frames>\n"
+std::string MeanSquareError::Report() {
+  std::stringstream report;
+  report << "Mse:" << error_;
+  report << " frames:" << frames_;
+  report << " err/frm:" << error_/frames_;
+  report << "\n";
+  return report.str();
+}
+
+
+/*
+ * CrossEntropy
+ */
+
+///Find the index of the maximum in a float array
+///
+/// @param ptr pointer to the first of N contiguous values
+/// @param N   number of values to scan
+/// @return index of the first occurrence of the largest non-NaN value;
+///         0 when every value is NaN; -1 only when N == 0
+///
+/// The search is seeded with the first non-NaN element instead of a fixed
+/// -1e20f threshold, so a non-empty array always yields a valid index
+/// (callers store the result in size_t and use it to index matrices).
+/// Whenever a value above -1e20f is present the result is the same as with
+/// the threshold-based search.
+inline int FindMaxId(const BaseFloat* ptr, size_t N) {
+  if(N == 0) return -1;
+  size_t best = 0;
+  bool seeded = false;
+  for(size_t i=0; i<N; i++) {
+    const BaseFloat val = ptr[i];
+    if(val != val) continue; //NaN never compares equal to itself, skip it
+    if(!seeded || val > ptr[best]) {
+      best = i;
+      seeded = true;
+    }
+  }
+  return static_cast<int>(best);
+}
+
+
+/// Accumulate cross-entropy statistics for one batch and fill the
+/// gradient buffer.
+///
+/// @param net_out network outputs, one row per frame
+/// @param target  target values, must have the same shape as net_out
+/// @param err     output buffer; re-initialised to the shape of net_out
+///                when its shape differs, then overwritten
+///
+/// Updates error_, frames_ and corr_, and, unless confusion_mode_ is
+/// NO_CONF, the confusion statistics.
+void
+CrossEntropy::Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err)
+{
+  //outputs and targets must agree in shape
+  assert(net_out.Rows() == target.Rows());
+  assert(net_out.Cols() == target.Cols());
+
+  //(re)allocate the gradient buffer when its shape does not fit
+  const bool err_fits = (err->Rows() == net_out.Rows()) && (err->Cols() == net_out.Cols());
+  if(!err_fits) {
+    err->Init(net_out.Rows(),net_out.Cols());
+  }
+
+  //(re)allocate the confusion buffers, sized by the number of target columns
+  const bool collect_confusion = (confusion_mode_ != NO_CONF);
+  if(collect_confusion) {
+    const bool conf_fits = (confusion_.Rows() == target.Cols()) && (confusion_.Cols() == target.Cols());
+    if(!conf_fits) {
+      confusion_.Init(target.Cols(),target.Cols());
+      confusion_count_.Init(target.Cols());
+      diag_confusion_.Init(target.Cols());
+    }
+  }
+
+  //global gradient (assuming it is applied on the softmax input)
+  err->Copy(net_out);
+  err->AddScaled(-1,target);
+
+  //per-frame argmax of targets and outputs, counting the matches
+  std::vector<size_t> best_target(target.Rows());
+  std::vector<size_t> best_netout(target.Rows());
+  const size_t n_frames = net_out.Rows();
+  int n_correct = 0;
+  for(size_t frm=0; frm<n_frames; ++frm) {
+    const int hyp = FindMaxId(net_out[frm].pData(),net_out.Cols());
+    const int ref = FindMaxId(target[frm].pData(),target.Cols());
+    if(hyp == ref) ++n_correct;
+    best_target[frm] = ref;
+    best_netout[frm] = hyp;
+  }
+
+  //loss: sum of target-weighted log outputs, each term floored at -1e10
+  double log_sum = 0;
+  for(size_t frm=0; frm<n_frames; ++frm) {
+    const size_t ref = best_target[frm];
+    if(target(frm,ref) == 1.0) {
+      //the maximal target is exactly one: use that single column only
+      BaseFloat val = log(net_out(frm,ref));
+      if(val < -1e10f) val = -1e10f;
+      log_sum += val;
+      continue;
+    }
+    //otherwise walk the whole target row, skipping zero entries
+    const size_t n_cols = net_out.Cols();
+    for(size_t cls=0; cls<n_cols; ++cls) {
+      if(target(frm,cls) == 0.0) continue;
+      BaseFloat val = target(frm,cls)*log(net_out(frm,cls));
+      if(val < -1e10f) val = -1e10f;
+      log_sum += val;
+    }
+  }
+
+  //accumulate the confusion statistics
+  if(collect_confusion) {
+    for(size_t frm=0; frm<n_frames; ++frm) {
+      const int ref = best_target[frm];
+      const int hyp = best_netout[frm];
+      switch(confusion_mode_) {
+        case MAX_CONF:
+          confusion_(ref,hyp) += 1;
+          break;
+        case SOFT_CONF:
+          confusion_[ref].Add(net_out[frm]);
+          break;
+        case DIAG_MAX_CONF:
+          diag_confusion_[ref] += ((ref==hyp)?1:0);
+          break;
+        case DIAG_SOFT_CONF:
+          diag_confusion_[ref] += net_out[frm][ref];
+          break;
+        default:
+          KALDI_ERR << "unknown confusion type" << confusion_mode_;
+      }
+      confusion_count_[ref] += 1;
+    }
+  }
+
+  //update the running statistics (error_ grows as log_sum is subtracted)
+  error_ -= log_sum;
+  frames_ += net_out.Rows();
+  corr_ += n_correct;
+}
+
+
+/// Format the accumulated cross-entropy statistics as text.
+///
+/// The first line holds the total error, the frame count, the error per
+/// frame and the percentage of correctly classified frames. Unless
+/// confusion_mode_ is NO_CONF, whitespace-separated class tags are read
+/// from the file named by output_label_map_, the confusion matrix is
+/// printed (MAX_CONF / SOFT_CONF only) and one accuracy line per class
+/// follows.
+///
+/// @return the report text
+std::string CrossEntropy::Report() {
+  std::stringstream ss;
+  ss << "Xent:" << error_ << " frames:" << frames_
+     << " err/frm:" << error_/frames_
+     << " correct[" << 100.0*corr_/frames_ << "%]"
+     << "\n";
+
+  if(confusion_mode_ != NO_CONF) {
+    //read class tags
+    std::vector<std::string> tag;
+    {
+      std::ifstream ifs(output_label_map_);
+      assert(ifs.good());
+      std::string str;
+      //Test the extraction itself rather than eof(): an eof()-driven loop
+      //stores one spurious tag after the last failed read and never ends
+      //when the stream is in a failed state without eofbit being set.
+      while(ifs >> str) {
+        tag.push_back(str);
+      }
+    }
+    //every class needs a tag, tag[i] is indexed below
+    assert(confusion_count_.Dim() <= tag.size());
+
+    //print confusion matrix
+    if(confusion_mode_ == MAX_CONF || confusion_mode_ == SOFT_CONF) {
+      ss << "Row:label Col:hyp\n" << confusion_ << "\n";
+    }
+
+    //***print per-target accuracies
+    for(int i=0; i<confusion_count_.Dim(); i++) {
+      //get the numerator
+      BaseFloat numerator = 0.0;
+      switch (confusion_mode_) {
+        case MAX_CONF: case SOFT_CONF:
+          numerator = confusion_[i][i];
+          break;
+        case DIAG_MAX_CONF: case DIAG_SOFT_CONF:
+          numerator = diag_confusion_[i];
+          break;
+        default:
+          KALDI_ERR << "Usupported confusion mode:" << confusion_mode_;
+      }
+      //add line to report
+      ss << std::setw(30) << tag[i] << " "
+         << std::setw(10) << 100.0*numerator/confusion_count_[i] << "%"
+         << " [" << numerator << "/" << confusion_count_[i] << "]\n";
+    } //***print per-target accuracies
+  }// != NO_CONF
+
+  return ss.str();
+}
+
+
+/// Add the statistics accumulated by another CrossEntropy instance to
+/// this one.
+///
+/// @param inst instance to merge from; must be a CrossEntropy, otherwise
+///             the reference dynamic_cast throws std::bad_cast
+///
+/// frames_, error_ and corr_ are always summed. The confusion statistics
+/// are summed only when confusion_mode_ is not NO_CONF and the other
+/// instance actually holds some.
+void CrossEntropy::MergeStats(const ObjectiveFunction& inst) {
+  const CrossEntropy& xent = dynamic_cast<const CrossEntropy&>(inst);
+  frames_ += xent.frames_; error_ += xent.error_; corr_ += xent.corr_;
+
+  //sum the confusion statistics
+  if(confusion_mode_ == NO_CONF) return;
+
+  //A source with an empty confusion matrix has nothing to contribute.
+  //Returning here keeps the size-mismatch branch below from resizing our
+  //own buffers to zero and losing what was accumulated so far.
+  if(xent.confusion_.Rows() == 0) return;
+
+  //adopt the dimensions of the source when ours differ
+  if(confusion_.Rows() != xent.confusion_.Rows()) {
+    confusion_.Init(xent.confusion_.Rows(),xent.confusion_.Cols());
+    confusion_count_.Init(xent.confusion_count_.Dim());
+    diag_confusion_.Init(xent.diag_confusion_.Dim());
+  }
+  confusion_.Add(xent.confusion_);
+  confusion_count_.Add(xent.confusion_count_);
+  diag_confusion_.Add(xent.diag_confusion_);
+}
+
+
+} // namespace TNet