summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/.svn/text-base/cuRbmSparse.h.svn-base
blob: 9d7e304f7713abb28e82fb2191ad72c081e349f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#ifndef _CU_RBM_SPARSE_H_
#define _CU_RBM_SPARSE_H_


#include "cuComponent.h"
#include "cumatrix.h"
#include "cuRbm.h"


#include "Matrix.h"
#include "Vector.h"


namespace TNet {

  /**
   * RBM component derived from CuRbmBase that additionally carries
   * sparsity parameters (target, decay factor, cost) and their buffers.
   *
   * Only the constructor, destructor, GetType() and GetName() are defined
   * inline in this header; the remaining methods are declared here and
   * implemented elsewhere.
   */
  class CuRbmSparse : public CuRbmBase
  {
    public:

      /// Sizes the buffers for nInputs visible and nOutputs hidden units.
      CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred);
      ~CuRbmSparse();

      /// Returns the component type tag.
      ComponentType GetType() const;
      /// Returns the component name tag.
      const char* GetName() const;

      // ---- CuUpdatableComponent API ----
      void PropagateFnc(const CuMatrix<BaseFloat>& X,
                        CuMatrix<BaseFloat>& Y);
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X,
                            CuMatrix<BaseFloat>& Y);

      void Update();

      // ---- RBM training API ----
      void Propagate(const CuMatrix<BaseFloat>& visProbs,
                     CuMatrix<BaseFloat>& hidProbs);
      void Reconstruct(const CuMatrix<BaseFloat>& hidState,
                       CuMatrix<BaseFloat>& visProbs);
      void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis,
                     const CuMatrix<BaseFloat>& pos_hid,
                     const CuMatrix<BaseFloat>& neg_vis,
                     const CuMatrix<BaseFloat>& neg_hid);

      /// Unit type of the visible layer.
      RbmUnitType VisType()
      {
        return mVisType;
      }

      /// Unit type of the hidden layer.
      RbmUnitType HidType()
      {
        return mHidType;
      }

      //static void BinarizeProbs(const CuMatrix<BaseFloat>& probs, CuMatrix<BaseFloat>& states);

      // ---- I/O ----
      void ReadFromStream(std::istream& rIn);
      void WriteToStream(std::ostream& rOut);

    protected:
      // NOTE: the declaration order below is also the member initialization
      // order used by the constructor; keep the two in sync.

      // Model parameters.
      CuMatrix<BaseFloat> mVisHid;   ///< Weight matrix (nInputs x nOutputs)
      CuVector<BaseFloat> mVisBias;  ///< Visible-layer biases (nInputs)
      CuVector<BaseFloat> mHidBias;  ///< Hidden-layer biases (nOutputs)

      // Update buffers, zeroed by the constructor.
      CuMatrix<BaseFloat> mVisHidCorrection;   ///< Update buffer for mVisHid
      CuVector<BaseFloat> mVisBiasCorrection;  ///< Update buffer for mVisBias
      CuVector<BaseFloat> mHidBiasCorrection;  ///< Update buffer for mHidBias

      /// Scratch matrix; by its name presumably used during backpropagation.
      CuMatrix<BaseFloat> mBackpropErrBuf;

      RbmUnitType mVisType;  ///< Visible unit type (BERNOULLI by default)
      RbmUnitType mHidType;  ///< Hidden unit type (BERNOULLI by default)

      // ---- sparsity ----
      BaseFloat mSparsityPrior;  ///< Sparsity target (unit activity prior)
      BaseFloat mLambda;         ///< Exponential decay factor for q, the observed probability of a unit being active
      BaseFloat mSparsityCost;   ///< Sparsity cost coefficient

      CuVector<BaseFloat> mSparsityQ;         ///< Per-hidden-unit q, starts at mSparsityPrior
      CuVector<BaseFloat> mSparsityQCurrent;  ///< Per-hidden-unit buffer, starts at zero
      CuVector<BaseFloat> mVisMean;           ///< Buffer for the mean of the visible units

  };




  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuRbmSparse::
  /**
   * Constructs the component.
   *
   * Forwards the dimensions and predecessor to CuRbmBase, sizes every
   * buffer from nInputs / nOutputs, selects BERNOULLI units on both layers
   * and installs the default sparsity settings. The weight matrix and the
   * two bias vectors are only sized here, not explicitly assigned a value.
   *
   * @param nInputs   number of visible units
   * @param nOutputs  number of hidden units
   * @param pPred     predecessor component, passed through to CuRbmBase
   */
  inline
  CuRbmSparse::CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuRbmBase(nInputs, nOutputs, pPred),
      // model parameters
      mVisHid(nInputs, nOutputs),
      mVisBias(nInputs),
      mHidBias(nOutputs),
      // update buffers
      mVisHidCorrection(nInputs, nOutputs),
      mVisBiasCorrection(nInputs),
      mHidBiasCorrection(nOutputs),
      mBackpropErrBuf(),
      // unit types
      mVisType(BERNOULLI),
      mHidType(BERNOULLI),
      // sparsity defaults
      mSparsityPrior(0.0001),
      mLambda(0.95),
      mSparsityCost(1e-7),
      mSparsityQ(nOutputs),
      mSparsityQCurrent(nOutputs),
      mVisMean(nInputs)
  {
    // The update buffers start from zero.
    mVisHidCorrection.SetConst(0.0);
    mVisBiasCorrection.SetConst(0.0);
    mHidBiasCorrection.SetConst(0.0);

    // q starts at the sparsity target; the other statistics start empty.
    mSparsityQ.SetConst(mSparsityPrior);
    mSparsityQCurrent.SetConst(0.0);
    mVisMean.SetConst(0.0);
  }


  /// Destructor with an empty body; nothing is released explicitly here.
  inline
  CuRbmSparse::~CuRbmSparse()
  {
  }

  /// @return the type tag of this component, CuComponent::RBM_SPARSE.
  inline CuComponent::ComponentType CuRbmSparse::GetType() const
  {
    return CuComponent::RBM_SPARSE;
  }

  /// @return the name tag of this component as a string literal.
  inline const char* CuRbmSparse::GetName() const
  {
    return "<rbmsparse>";
  }



} //namespace



#endif