summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuObjectiveFunction.cc
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuObjectiveFunction.cc
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/CuTNetLib/cuObjectiveFunction.cc')
-rw-r--r--src/CuTNetLib/cuObjectiveFunction.cc87
1 files changed, 87 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuObjectiveFunction.cc b/src/CuTNetLib/cuObjectiveFunction.cc
new file mode 100644
index 0000000..e2b0a1d
--- /dev/null
+++ b/src/CuTNetLib/cuObjectiveFunction.cc
@@ -0,0 +1,87 @@
+
+#include "cuObjectiveFunction.h"
+
+#include "Error.h"
+#include "cumath.h"
+
+
+namespace TNet
+{
+
+
+
+ CuObjectiveFunction*
+ CuObjectiveFunction::
+ Factory(ObjFunType type) {
+ CuObjectiveFunction* ret = NULL;
+ switch(type) {
+ case MEAN_SQUARE_ERROR: ret = new CuMeanSquareError; break;
+ case CROSS_ENTROPY: ret = new CuCrossEntropy; break;
+ default: Error("Unknown ObjFun type");
+ }
+ return ret;
+ }
+
+
+ void
+ CuMeanSquareError::
+ Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError)
+ {
+ //get the global error
+ rNetError.CopyFrom(rNetOutput);
+ rNetError.AddScaled(-1.0,rDesired,1.0);
+
+ //calculate the MSE
+ mAuxMat.CopyFrom(rNetError);
+ mAuxMat.MulElem(mAuxMat);
+
+ mAuxVec.Init(mAuxMat.Cols());
+ mAuxVec.AddColSum(1.0,mAuxMat,0.0);
+ mAuxVec.CopyTo(mAuxVecHost);
+
+ mError += mAuxVecHost.Sum();
+
+ //count the frames
+ mFrames += rNetError.Rows();
+ }
+
+ void
+ CuCrossEntropy::
+ Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError)
+ {
+ if(rDesired.Cols() != rNetOutput.Cols()) {
+ std::ostringstream os;
+ os << "Non-matching dimensions of network output with training targets!!!"
+ << " Netoutput:" << rNetOutput.Cols()
+ << " Targets:" << rDesired.Cols();
+ Error(os.str());
+ }
+
+ //get the global error
+ //dXent/dSoftmax_in = y-d
+ rNetError.CopyFrom(rNetOutput);
+ rNetError.AddScaled(-1.0,rDesired,1.0);
+
+ //check classification
+ mClassifyVec.Init(rNetOutput.Rows());
+ CuMath<BaseFloat>::CheckClass(rNetOutput,rDesired,mClassifyVec);
+ mClassifyVec.CopyTo(mClassifyVecHost);
+ mCorrect += mClassifyVecHost.Sum();
+
+ //calculate Xent
+ mAuxMat.CopyFrom(rNetOutput);
+ mAuxMat.LogElem();
+ mAuxMat.MulElem(rDesired);
+
+ mAuxVec.Init(mAuxMat.Cols());
+ mAuxVec.AddColSum(-1.0,mAuxMat,0.0);
+ mAuxVec.CopyTo(mAuxVecHost);
+
+ mError += mAuxVecHost.Sum();
+
+ //count the frames
+ mFrames += rNetError.Rows();
+ }
+
+
+} // namespace TNet