summaryrefslogtreecommitdiff
path: root/src/TNetLib/.svn/text-base/ObjFun.h.svn-base
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNetLib/.svn/text-base/ObjFun.h.svn-base
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/TNetLib/.svn/text-base/ObjFun.h.svn-base')
-rw-r--r--src/TNetLib/.svn/text-base/ObjFun.h.svn-base160
1 files changed, 160 insertions, 0 deletions
diff --git a/src/TNetLib/.svn/text-base/ObjFun.h.svn-base b/src/TNetLib/.svn/text-base/ObjFun.h.svn-base
new file mode 100644
index 0000000..c458340
--- /dev/null
+++ b/src/TNetLib/.svn/text-base/ObjFun.h.svn-base
@@ -0,0 +1,160 @@
+#ifndef _TNET_OBJ_FUN_H
+#define _TNET_OBJ_FUN_H
+
+#include <cassert>
+#include <limits>
+#include <cmath>
+
+#include "Matrix.h"
+#include "Vector.h"
+
+namespace TNet {
+
+ /**
+ * General interface for objective functions
+ */
+ class ObjectiveFunction
+ {
+ public:
+ /// Enum with objective function types
+ typedef enum {
+ OBJ_FUN_I = 0x0300,
+ MEAN_SQUARE_ERROR,
+ CROSS_ENTROPY,
+ } ObjFunType;
+
+ public:
+ /// Factory for creating objective function instances
+ static ObjectiveFunction* Factory(ObjFunType type);
+
+ //////////////////////////////////////////////////////////////
+ // Interface specification
+ protected:
+ ObjectiveFunction() { }; /// constructor
+ public:
+ virtual ~ObjectiveFunction() { }; /// destructor
+
+ virtual ObjFunType GetType() = 0;
+ virtual const char* GetName() = 0;
+ virtual ObjectiveFunction* Clone() = 0;
+
+ ///calculate error of network output
+ virtual void Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err) = 0;
+
+ ///get the accumulated error
+ virtual double GetError() = 0;
+ ///the number of processed frames
+ virtual size_t GetFrames() = 0;
+
+ ///report the error to string
+ virtual std::string Report() = 0;
+
+ ///sum the frame counts from more instances
+ virtual void MergeStats(const ObjectiveFunction& inst) = 0;
+ };
+
+
+
+ /**
+ * Mean square error function
+ */
+ class MeanSquareError : public ObjectiveFunction
+ {
+ public:
+ MeanSquareError()
+ : ObjectiveFunction(), frames_(0), error_(0)
+ { }
+
+ ~MeanSquareError()
+ { }
+
+ ObjFunType GetType()
+ { return MEAN_SQUARE_ERROR; }
+
+ const char* GetName()
+ { return "<MeanSquareError>"; }
+
+ ObjectiveFunction* Clone()
+ { return new MeanSquareError(*this); }
+
+ void Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err);
+
+ size_t GetFrames()
+ { return frames_; }
+
+ double GetError()
+ { return error_; }
+
+ std::string Report();
+
+ void MergeStats(const ObjectiveFunction& inst) {
+ const MeanSquareError& mse = dynamic_cast<const MeanSquareError&>(inst);
+ frames_ += mse.frames_; error_ += mse.error_;
+ }
+
+ private:
+ size_t frames_;
+ double error_;
+
+ };
+
+
+ /**
+ * Cross entropy error function
+ */
+ class CrossEntropy : public ObjectiveFunction
+ {
+ public:
+ enum ConfusionMode { NO_CONF=0, MAX_CONF, SOFT_CONF, DIAG_MAX_CONF, DIAG_SOFT_CONF };
+
+ public:
+ CrossEntropy()
+ : ObjectiveFunction(), frames_(0), error_(0), corr_(0), confusion_mode_(NO_CONF), output_label_map_(NULL)
+ { }
+
+ ~CrossEntropy()
+ { }
+
+ ObjFunType GetType()
+ { return CROSS_ENTROPY; }
+
+ const char* GetName()
+ { return "<cross_entropy>"; }
+
+ ObjectiveFunction* Clone()
+ { return new CrossEntropy(*this); }
+
+ void Evaluate(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* err);
+
+ size_t GetFrames()
+ { return frames_; }
+
+ double GetError()
+ { return error_; }
+
+ void SetConfusionMode(enum ConfusionMode m)
+ { confusion_mode_ = m; }
+
+ void SetOutputLabelMap(const char* map)
+ { output_label_map_ = map; }
+
+ std::string Report();
+
+ void MergeStats(const ObjectiveFunction& inst);
+ private:
+ size_t frames_;
+ double error_;
+ size_t corr_;
+
+ ConfusionMode confusion_mode_;
+ Matrix<float> confusion_;
+ Vector<int> confusion_count_;
+ Vector<double> diag_confusion_;
+ const char* output_label_map_;
+ };
+
+
+} //namespace TNet
+
+
+#endif