diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuRbmSparse.h | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/CuTNetLib/cuRbmSparse.h')
-rw-r--r-- | src/CuTNetLib/cuRbmSparse.h | 134 |
1 files changed, 134 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuRbmSparse.h b/src/CuTNetLib/cuRbmSparse.h new file mode 100644 index 0000000..9d7e304 --- /dev/null +++ b/src/CuTNetLib/cuRbmSparse.h @@ -0,0 +1,134 @@ +#ifndef _CU_RBM_SPARSE_H_ +#define _CU_RBM_SPARSE_H_ + + +#include "cuComponent.h" +#include "cumatrix.h" +#include "cuRbm.h" + + +#include "Matrix.h" +#include "Vector.h" + + +namespace TNet { + + class CuRbmSparse : public CuRbmBase + { + public: + + CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred); + ~CuRbmSparse(); + + ComponentType GetType() const; + const char* GetName() const; + + //CuUpdatableComponent API + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + + void Update(); + + //RBM training API + void Propagate(const CuMatrix<BaseFloat>& visProbs, CuMatrix<BaseFloat>& hidProbs); + void Reconstruct(const CuMatrix<BaseFloat>& hidState, CuMatrix<BaseFloat>& visProbs); + void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis, const CuMatrix<BaseFloat>& pos_hid, const CuMatrix<BaseFloat>& neg_vis, const CuMatrix<BaseFloat>& neg_hid); + + RbmUnitType VisType() + { return mVisType; } + + RbmUnitType HidType() + { return mHidType; } + + //static void BinarizeProbs(const CuMatrix<BaseFloat>& probs, CuMatrix<BaseFloat>& states); + + //I/O + void ReadFromStream(std::istream& rIn); + void WriteToStream(std::ostream& rOut); + + protected: + CuMatrix<BaseFloat> mVisHid; ///< Matrix with neuron weights + CuVector<BaseFloat> mVisBias; ///< Vector with biases + CuVector<BaseFloat> mHidBias; ///< Vector with biases + + CuMatrix<BaseFloat> mVisHidCorrection; ///< Matrix for linearity updates + CuVector<BaseFloat> mVisBiasCorrection; ///< Vector for bias updates + CuVector<BaseFloat> mHidBiasCorrection; ///< Vector for bias updates + + CuMatrix<BaseFloat> mBackpropErrBuf; + + RbmUnitType mVisType; + RbmUnitType mHidType; + + ////// sparsity + BaseFloat mSparsityPrior; ///< sparsity target (unit activity prior) + BaseFloat mLambda; ///< exponential decay factor for q (observed probability of unit to be active) + BaseFloat mSparsityCost; ///< sparsity cost coef. + + CuVector<BaseFloat> mSparsityQ; + CuVector<BaseFloat> mSparsityQCurrent; + CuVector<BaseFloat> mVisMean; ///< buffer for mean visible + + }; + + + + + //////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuRbmSparse:: + inline + CuRbmSparse:: + CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuRbmBase(nInputs, nOutputs, pPred), + mVisHid(nInputs,nOutputs), + mVisBias(nInputs), mHidBias(nOutputs), + mVisHidCorrection(nInputs,nOutputs), + mVisBiasCorrection(nInputs), mHidBiasCorrection(nOutputs), + mBackpropErrBuf(), + mVisType(BERNOULLI), + mHidType(BERNOULLI), + + mSparsityPrior(0.0001), + mLambda(0.95), + mSparsityCost(1e-7), + mSparsityQ(nOutputs), + mSparsityQCurrent(nOutputs), + mVisMean(nInputs) + { + mVisHidCorrection.SetConst(0.0); + mVisBiasCorrection.SetConst(0.0); + mHidBiasCorrection.SetConst(0.0); + + mSparsityQ.SetConst(mSparsityPrior); + mSparsityQCurrent.SetConst(0.0); + mVisMean.SetConst(0.0); + } + + + inline + CuRbmSparse:: + ~CuRbmSparse() + { } + + inline CuComponent::ComponentType + CuRbmSparse:: + GetType() const + { + return CuComponent::RBM_SPARSE; + } + + inline const char* + CuRbmSparse:: + GetName() const + { + return "<rbmsparse>"; + } + + + +} //namespace + + + +#endif |