diff options
Diffstat (limited to 'src/CuTNetLib/cuObjectiveFunction.h')
-rw-r--r-- | src/CuTNetLib/cuObjectiveFunction.h | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuObjectiveFunction.h b/src/CuTNetLib/cuObjectiveFunction.h new file mode 100644 index 0000000..4dd0c32 --- /dev/null +++ b/src/CuTNetLib/cuObjectiveFunction.h @@ -0,0 +1,185 @@ +#ifndef _CUOBJ_FUN_I_ +#define _CUOBJ_FUN_I_ + +#include <cassert> +#include <limits> +#include <cmath> +#include <sstream> + +#include "Vector.h" +#include "cuvector.h" +#include "cumatrix.h" + +/** + * \file cuObjectiveFunction.h + * \brief Objective Functions used to compare the model and data + */ + +/** + * \defgroup CuModelObj CuNN Objective Functions + * \ingroup CuNNComp + */ + +namespace TNet +{ + + + /** + * \brief General interface for objective functions + */ + class CuObjectiveFunction + { + public: + /// Enum with objective function types + typedef enum { + OBJ_FUN_I = 0x0300, + MEAN_SQUARE_ERROR, + CROSS_ENTROPY, + } ObjFunType; + + /// Factory for creating objective function instances + static CuObjectiveFunction* Factory(ObjFunType type); + + ////////////////////////////////////////////////////////////// + // Interface specification + public: + CuObjectiveFunction() + { } + + virtual ~CuObjectiveFunction() + { } + + virtual ObjFunType GetTypeId() = 0; + virtual const char* GetTypeLabel() = 0; + + /// evaluates the data, calculate global error + /// \param[in] rNetOutput CuNN output as generated by model + /// \param[in] rDesired Desired output specified by data + /// \param[out] rNetError Derivative of the Energy Function + virtual void Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError) = 0; + + ///get the average per frame error + virtual double GetError() = 0; + ///the number of processed frames + virtual size_t GetFrames() = 0; + ///report the error to std::cout + virtual std::string Report() = 0; + }; + + + + + /** + * \brief Means square error, useful for autoencoders, RBMs et al. + * + * \ingroup CuModelObj + * Calculate: \f[ ||\vec{ModelOutput}-\vec{Label}||^2 \f] + */ + class CuMeanSquareError : public CuObjectiveFunction + { + public: + CuMeanSquareError() + : mError(0), mFrames(0) + { } + + virtual ~CuMeanSquareError() + { } + + ObjFunType GetTypeId() + { return CuObjectiveFunction::MEAN_SQUARE_ERROR; } + + const char* GetTypeLabel() + { return "<mean_square_error>"; } + + void Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError); + + double GetError() + { return mError; } + + size_t GetFrames() + { return mFrames; } + + std::string Report() + { + std::ostringstream ss; + ss << "Mse:" << mError << " frames:" << mFrames + << " err/frm:" << mError/mFrames << "\n"; + return ss.str(); + } + + private: + double mError; + size_t mFrames; + + CuMatrix<BaseFloat> mAuxMat; + CuVector<BaseFloat> mAuxVec; + Vector<BaseFloat> mAuxVecHost; + + }; + + + /** + * \brief Cross entropy, it assumes desired vectors as output values + * + * \ingroup CuModelObj + * Calculate: \f[ -\ln(\vec{ModelOutput}) \cdot \vec{Label} \f] + */ + class CuCrossEntropy : public CuObjectiveFunction + { + public: + CuCrossEntropy() + : mError(0), mFrames(0), mCorrect(0) + { } + + ~CuCrossEntropy() + { } + + ObjFunType GetTypeId() + { return CuObjectiveFunction::CROSS_ENTROPY; } + + const char* GetTypeLabel() + { return "<cross_entropy>"; } + + void Evaluate(const CuMatrix<BaseFloat>& rNetOutput, const CuMatrix<BaseFloat>& rDesired, CuMatrix<BaseFloat>& rNetError); + + double GetError() + { return mError; } + + size_t GetFrames() + { return mFrames; } + + std::string Report() + { + std::ostringstream ss; + //for compatibility with SNet + //ss << " correct: >> " << 100.0*mCorrect/mFrames << "% <<\n"; + + //current new format... + ss << "Xent:" << mError << " frames:" << mFrames + << " err/frm:" << mError/mFrames + << " correct[" << 100.0*mCorrect/mFrames << "%]" + << "\n"; + return ss.str(); + } + + private: + double mError; + size_t mFrames; + size_t mCorrect; + + CuMatrix<BaseFloat> mAuxMat; + CuVector<BaseFloat> mAuxVec; + Vector<BaseFloat> mAuxVecHost; + + CuVector<int> mClassifyVec; + Vector<int> mClassifyVecHost; + }; + + + + + +} //namespace TNet + + +#endif |