#ifndef _CU_RBM_H_ #define _CU_RBM_H_ #include "cuComponent.h" #include "cumatrix.h" #include "Matrix.h" #include "Vector.h" namespace TNet { class CuRbmBase : public CuUpdatableComponent { public: typedef enum { BERNOULLI, GAUSSIAN } RbmUnitType; CuRbmBase(size_t nInputs, size_t nOutputs, CuComponent *pPred) : CuUpdatableComponent(nInputs, nOutputs, pPred) { } virtual void Propagate( const CuMatrix& visProbs, CuMatrix& hidProbs ) = 0; virtual void Reconstruct( const CuMatrix& hidState, CuMatrix& visProbs ) = 0; virtual void RbmUpdate( const CuMatrix& pos_vis, const CuMatrix& pos_hid, const CuMatrix& neg_vis, const CuMatrix& neg_hid ) = 0; virtual RbmUnitType VisType() = 0; virtual RbmUnitType HidType() = 0; }; class CuRbm : public CuRbmBase { public: CuRbm(size_t nInputs, size_t nOutputs, CuComponent *pPred); ~CuRbm(); ComponentType GetType() const; const char* GetName() const; //CuUpdatableComponent API void PropagateFnc(const CuMatrix& X, CuMatrix& Y); void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); void Update(); //RBM training API void Propagate(const CuMatrix& visProbs, CuMatrix& hidProbs); void Reconstruct(const CuMatrix& hidState, CuMatrix& visProbs); void RbmUpdate(const CuMatrix& pos_vis, const CuMatrix& pos_hid, const CuMatrix& neg_vis, const CuMatrix& neg_hid); RbmUnitType VisType() { return mVisType; } RbmUnitType HidType() { return mHidType; } //static void BinarizeProbs(const CuMatrix& probs, CuMatrix& states); //I/O void ReadFromStream(std::istream& rIn); void WriteToStream(std::ostream& rOut); protected: CuMatrix mVisHid; ///< Matrix with neuron weights CuVector mVisBias; ///< Vector with biases CuVector mHidBias; ///< Vector with biases CuMatrix mVisHidCorrection; ///< Matrix for linearity updates CuVector mVisBiasCorrection; ///< Vector for bias updates CuVector mHidBiasCorrection; ///< Vector for bias updates CuMatrix mBackpropErrBuf; RbmUnitType mVisType; RbmUnitType mHidType; }; //////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // CuRbm:: inline CuRbm:: CuRbm(size_t nInputs, size_t nOutputs, CuComponent *pPred) : CuRbmBase(nInputs, nOutputs, pPred), mVisHid(nInputs,nOutputs), mVisBias(nInputs), mHidBias(nOutputs), mVisHidCorrection(nInputs,nOutputs), mVisBiasCorrection(nInputs), mHidBiasCorrection(nOutputs), mBackpropErrBuf(), mVisType(BERNOULLI), mHidType(BERNOULLI) { mVisHidCorrection.SetConst(0.0); mVisBiasCorrection.SetConst(0.0); mHidBiasCorrection.SetConst(0.0); } inline CuRbm:: ~CuRbm() { } inline CuComponent::ComponentType CuRbm:: GetType() const { return CuComponent::RBM; } inline const char* CuRbm:: GetName() const { return ""; } } //namespace #endif