diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuObjectiveFunction.cc | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/CuTNetLib/cuObjectiveFunction.cc')
-rw-r--r-- | src/CuTNetLib/cuObjectiveFunction.cc | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuObjectiveFunction.cc b/src/CuTNetLib/cuObjectiveFunction.cc new file mode 100644 index 0000000..e2b0a1d --- /dev/null +++ b/src/CuTNetLib/cuObjectiveFunction.cc @@ -0,0 +1,87 @@ + +#include "cuObjectiveFunction.h" + +#include "Error.h" +#include "cumath.h" + + +namespace TNet +{ + + + + CuObjectiveFunction* + CuObjectiveFunction:: + Factory(ObjFunType type) { + CuObjectiveFunction* ret = NULL; + switch(type) { + case MEAN_SQUARE_ERROR: ret = new CuMeanSquareError; break; + case CROSS_ENTROPY: ret = new CuCrossEntropy; break; + default: Error("Unknown ObjFun type"); + } + return ret; + } + + + void + CuMeanSquareError:: + Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError) + { + //get the global error + rNetError.CopyFrom(rNetOutput); + rNetError.AddScaled(-1.0,rDesired,1.0); + + //calculate the MSE + mAuxMat.CopyFrom(rNetError); + mAuxMat.MulElem(mAuxMat); + + mAuxVec.Init(mAuxMat.Cols()); + mAuxVec.AddColSum(1.0,mAuxMat,0.0); + mAuxVec.CopyTo(mAuxVecHost); + + mError += mAuxVecHost.Sum(); + + //count the frames + mFrames += rNetError.Rows(); + } + + void + CuCrossEntropy:: + Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError) + { + if(rDesired.Cols() != rNetOutput.Cols()) { + std::ostringstream os; + os << "Non-matching dimensions of network output with training targets!!!" + << " Netoutput:" << rNetOutput.Cols() + << " Targets:" << rDesired.Cols(); + Error(os.str()); + } + + //get the global error + //dXent/dSoftmax_in = y-d + rNetError.CopyFrom(rNetOutput); + rNetError.AddScaled(-1.0,rDesired,1.0); + + //check classification + mClassifyVec.Init(rNetOutput.Rows()); + CuMath<BaseFloat>::CheckClass(rNetOutput,rDesired,mClassifyVec); + mClassifyVec.CopyTo(mClassifyVecHost); + mCorrect += mClassifyVecHost.Sum(); + + //calculate Xent + mAuxMat.CopyFrom(rNetOutput); + mAuxMat.LogElem(); + mAuxMat.MulElem(rDesired); + + mAuxVec.Init(mAuxMat.Cols()); + mAuxVec.AddColSum(-1.0,mAuxMat,0.0); + mAuxVec.CopyTo(mAuxVecHost); + + mError += mAuxVecHost.Sum(); + + //count the frames + mFrames += rNetError.Rows(); + } + + +} // namespace TNet |