From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- src/CuTNetLib/cuRbm.h | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 src/CuTNetLib/cuRbm.h (limited to 'src/CuTNetLib/cuRbm.h') diff --git a/src/CuTNetLib/cuRbm.h b/src/CuTNetLib/cuRbm.h new file mode 100644 index 0000000..c1e984b --- /dev/null +++ b/src/CuTNetLib/cuRbm.h @@ -0,0 +1,146 @@ +#ifndef _CU_RBM_H_ +#define _CU_RBM_H_ + + +#include "cuComponent.h" +#include "cumatrix.h" + + +#include "Matrix.h" +#include "Vector.h" + + +namespace TNet { + + class CuRbmBase : public CuUpdatableComponent + { + public: + typedef enum { + BERNOULLI, + GAUSSIAN + } RbmUnitType; + + CuRbmBase(size_t nInputs, size_t nOutputs, CuComponent *pPred) : + CuUpdatableComponent(nInputs, nOutputs, pPred) + { } + + + virtual void Propagate( + const CuMatrix& visProbs, + CuMatrix& hidProbs + ) = 0; + virtual void Reconstruct( + const CuMatrix& hidState, + CuMatrix& visProbs + ) = 0; + virtual void RbmUpdate( + const CuMatrix& pos_vis, + const CuMatrix& pos_hid, + const CuMatrix& neg_vis, + const CuMatrix& neg_hid + ) = 0; + + virtual RbmUnitType VisType() = 0; + virtual RbmUnitType HidType() = 0; + }; + + + class CuRbm : public CuRbmBase + { + public: + + CuRbm(size_t nInputs, size_t nOutputs, CuComponent *pPred); + ~CuRbm(); + + ComponentType GetType() const; + const char* GetName() const; + + //CuUpdatableComponent API + void PropagateFnc(const CuMatrix& X, CuMatrix& Y); + void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); + + void Update(); + + //RBM training API + void Propagate(const CuMatrix& visProbs, CuMatrix& hidProbs); + void Reconstruct(const CuMatrix& hidState, CuMatrix& visProbs); + void RbmUpdate(const CuMatrix& pos_vis, const CuMatrix& pos_hid, const CuMatrix& neg_vis, const CuMatrix& neg_hid); + + RbmUnitType VisType() + { return mVisType; } + + RbmUnitType HidType() + { return mHidType; } + + //static void BinarizeProbs(const CuMatrix& probs, CuMatrix& states); + + //I/O + void ReadFromStream(std::istream& rIn); + void WriteToStream(std::ostream& rOut); + + protected: + CuMatrix mVisHid; ///< Matrix with neuron weights + CuVector mVisBias; ///< Vector with biases + CuVector mHidBias; ///< Vector with biases + + CuMatrix mVisHidCorrection; ///< Matrix for linearity updates + CuVector mVisBiasCorrection; ///< Vector for bias updates + CuVector mHidBiasCorrection; ///< Vector for bias updates + + CuMatrix mBackpropErrBuf; + + RbmUnitType mVisType; + RbmUnitType mHidType; + + }; + + + + + //////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuRbm:: + inline + CuRbm:: + CuRbm(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuRbmBase(nInputs, nOutputs, pPred), + mVisHid(nInputs,nOutputs), + mVisBias(nInputs), mHidBias(nOutputs), + mVisHidCorrection(nInputs,nOutputs), + mVisBiasCorrection(nInputs), mHidBiasCorrection(nOutputs), + mBackpropErrBuf(), + mVisType(BERNOULLI), + mHidType(BERNOULLI) + { + mVisHidCorrection.SetConst(0.0); + mVisBiasCorrection.SetConst(0.0); + mHidBiasCorrection.SetConst(0.0); + } + + + inline + CuRbm:: + ~CuRbm() + { } + + inline CuComponent::ComponentType + CuRbm:: + GetType() const + { + return CuComponent::RBM; + } + + inline const char* + CuRbm:: + GetName() const + { + return ""; + } + + + +} //namespace + + + +#endif -- cgit v1.2.3-70-g09d2