summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base
diff options
context:
space:
mode:
Diffstat (limited to 'src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base')
-rw-r--r--src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base134
1 files changed, 134 insertions, 0 deletions
diff --git a/src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base b/src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base
new file mode 100644
index 0000000..9d7e304
--- /dev/null
+++ b/src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base
@@ -0,0 +1,134 @@
+#ifndef _CU_RBM_SPARSE_H_
+#define _CU_RBM_SPARSE_H_
+
+
+#include "cuComponent.h"
+#include "cumatrix.h"
+#include "cuRbm.h"
+
+
+#include "Matrix.h"
+#include "Vector.h"
+
+
+namespace TNet {
+
+ class CuRbmSparse : public CuRbmBase
+ {
+ public:
+
+ CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+ ~CuRbmSparse();
+
+ ComponentType GetType() const;
+ const char* GetName() const;
+
+ //CuUpdatableComponent API
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+
+ void Update();
+
+ //RBM training API
+ void Propagate(const CuMatrix<BaseFloat>& visProbs, CuMatrix<BaseFloat>& hidProbs);
+ void Reconstruct(const CuMatrix<BaseFloat>& hidState, CuMatrix<BaseFloat>& visProbs);
+ void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis, const CuMatrix<BaseFloat>& pos_hid, const CuMatrix<BaseFloat>& neg_vis, const CuMatrix<BaseFloat>& neg_hid);
+
+ RbmUnitType VisType()
+ { return mVisType; }
+
+ RbmUnitType HidType()
+ { return mHidType; }
+
+ //static void BinarizeProbs(const CuMatrix<BaseFloat>& probs, CuMatrix<BaseFloat>& states);
+
+ //I/O
+ void ReadFromStream(std::istream& rIn);
+ void WriteToStream(std::ostream& rOut);
+
+ protected:
+ CuMatrix<BaseFloat> mVisHid; ///< Matrix with neuron weights
+ CuVector<BaseFloat> mVisBias; ///< Vector with biases
+ CuVector<BaseFloat> mHidBias; ///< Vector with biases
+
+ CuMatrix<BaseFloat> mVisHidCorrection; ///< Matrix for linearity updates
+ CuVector<BaseFloat> mVisBiasCorrection; ///< Vector for bias updates
+ CuVector<BaseFloat> mHidBiasCorrection; ///< Vector for bias updates
+
+ CuMatrix<BaseFloat> mBackpropErrBuf;
+
+ RbmUnitType mVisType;
+ RbmUnitType mHidType;
+
+ ////// sparsity
+ BaseFloat mSparsityPrior; ///< sparsity target (unit activity prior)
+ BaseFloat mLambda; ///< exponential decay factor for q (observed probability of unit to be active)
+ BaseFloat mSparsityCost; ///< sparsity cost coef.
+
+ CuVector<BaseFloat> mSparsityQ;
+ CuVector<BaseFloat> mSparsityQCurrent;
+ CuVector<BaseFloat> mVisMean; ///< buffer for mean visible
+
+ };
+
+
+
+
+ ////////////////////////////////////////////////////////////////////////////
+ // INLINE FUNCTIONS
+ // CuRbmSparse::
+ inline
+ CuRbmSparse::
+ CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+ : CuRbmBase(nInputs, nOutputs, pPred),
+ mVisHid(nInputs,nOutputs),
+ mVisBias(nInputs), mHidBias(nOutputs),
+ mVisHidCorrection(nInputs,nOutputs),
+ mVisBiasCorrection(nInputs), mHidBiasCorrection(nOutputs),
+ mBackpropErrBuf(),
+ mVisType(BERNOULLI),
+ mHidType(BERNOULLI),
+
+ mSparsityPrior(0.0001),
+ mLambda(0.95),
+ mSparsityCost(1e-7),
+ mSparsityQ(nOutputs),
+ mSparsityQCurrent(nOutputs),
+ mVisMean(nInputs)
+ {
+ mVisHidCorrection.SetConst(0.0);
+ mVisBiasCorrection.SetConst(0.0);
+ mHidBiasCorrection.SetConst(0.0);
+
+ mSparsityQ.SetConst(mSparsityPrior);
+ mSparsityQCurrent.SetConst(0.0);
+ mVisMean.SetConst(0.0);
+ }
+
+
+ inline
+ CuRbmSparse::
+ ~CuRbmSparse()
+ { }
+
+ inline CuComponent::ComponentType
+ CuRbmSparse::
+ GetType() const
+ {
+ return CuComponent::RBM_SPARSE;
+ }
+
+ inline const char*
+ CuRbmSparse::
+ GetName() const
+ {
+ return "<rbmsparse>";
+ }
+
+
+
+} //namespace
+
+
+
+#endif