summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuRbm.h
blob: c1e984b1cf7fd0f136a0f11e816fe583de16650d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#ifndef _CU_RBM_H_
#define _CU_RBM_H_


#include "cuComponent.h"
#include "cumatrix.h"


#include "Matrix.h"
#include "Vector.h"


namespace TNet {

  /// Abstract interface for RBM layers.
  ///
  /// In addition to the CuUpdatableComponent behaviour inherited from the
  /// base class, every RBM implementation provides:
  ///  - a visible -> hidden pass (Propagate),
  ///  - a hidden -> visible pass (Reconstruct),
  ///  - a parameter update from positive (pos_*) and negative (neg_*)
  ///    visible/hidden matrices (RbmUpdate),
  ///  - the unit type of its visible and hidden layers.
  class CuRbmBase : public CuUpdatableComponent
  {
   public:
    /// Unit type of a visible or hidden layer.
    enum RbmUnitType {
      BERNOULLI,
      GAUSSIAN
    };

    /// Passes the layer dimensions and the predecessor component straight
    /// through to CuUpdatableComponent; holds no state of its own.
    CuRbmBase(size_t nInputs, size_t nOutputs, CuComponent *pPred)
      : CuUpdatableComponent(nInputs, nOutputs, pPred)
    {
    }

    /// Computes hidden-layer output (hidProbs) from visible input (visProbs).
    virtual void Propagate(const CuMatrix<BaseFloat>& visProbs,
                           CuMatrix<BaseFloat>& hidProbs) = 0;

    /// Computes visible-layer output (visProbs) from a hidden state (hidState).
    virtual void Reconstruct(const CuMatrix<BaseFloat>& hidState,
                             CuMatrix<BaseFloat>& visProbs) = 0;

    /// Updates the RBM parameters from positive-phase and negative-phase
    /// visible/hidden matrices; the exact rule is up to the implementation.
    virtual void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis,
                           const CuMatrix<BaseFloat>& pos_hid,
                           const CuMatrix<BaseFloat>& neg_vis,
                           const CuMatrix<BaseFloat>& neg_hid) = 0;

    /// Unit type of the visible layer.
    virtual RbmUnitType VisType() = 0;
    /// Unit type of the hidden layer.
    virtual RbmUnitType HidType() = 0;
  };


  /// Concrete RBM layer with a dense visible x hidden weight matrix and
  /// separate visible and hidden bias vectors.
  ///
  /// The constructor initializes both the visible and the hidden unit type
  /// to BERNOULLI.
  class CuRbm : public CuRbmBase
  {
    public:
      CuRbm(size_t nInputs, size_t nOutputs, CuComponent *pPred);
      ~CuRbm();

      ComponentType GetType() const;
      const char* GetName() const;

      // ---- CuUpdatableComponent interface --------------------------------
      void PropagateFnc(const CuMatrix<BaseFloat>& X,
                        CuMatrix<BaseFloat>& Y);
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X,
                            CuMatrix<BaseFloat>& Y);
      void Update();

      // ---- RBM training interface (CuRbmBase) ----------------------------
      void Propagate(const CuMatrix<BaseFloat>& visProbs,
                     CuMatrix<BaseFloat>& hidProbs);
      void Reconstruct(const CuMatrix<BaseFloat>& hidState,
                       CuMatrix<BaseFloat>& visProbs);
      void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis,
                     const CuMatrix<BaseFloat>& pos_hid,
                     const CuMatrix<BaseFloat>& neg_vis,
                     const CuMatrix<BaseFloat>& neg_hid);

      /// Unit type of the visible layer.
      RbmUnitType VisType() { return mVisType; }
      /// Unit type of the hidden layer.
      RbmUnitType HidType() { return mHidType; }

      //static void BinarizeProbs(const CuMatrix<BaseFloat>& probs, CuMatrix<BaseFloat>& states);

      // ---- serialization -------------------------------------------------
      void ReadFromStream(std::istream& rIn);
      void WriteToStream(std::ostream& rOut);

    protected:
      // Members are initialized in declaration order; keep this order in
      // step with the constructor's initializer list.

      CuMatrix<BaseFloat> mVisHid;   ///< weights, sized nInputs x nOutputs
      CuVector<BaseFloat> mVisBias;  ///< visible biases, length nInputs
      CuVector<BaseFloat> mHidBias;  ///< hidden biases, length nOutputs

      CuMatrix<BaseFloat> mVisHidCorrection;   ///< weight update buffer (zeroed at construction)
      CuVector<BaseFloat> mVisBiasCorrection;  ///< visible-bias update buffer (zeroed at construction)
      CuVector<BaseFloat> mHidBiasCorrection;  ///< hidden-bias update buffer (zeroed at construction)

      CuMatrix<BaseFloat> mBackpropErrBuf;  ///< default-constructed here; presumably used by BackpropagateFnc — TODO confirm

      RbmUnitType mVisType;  ///< visible unit type
      RbmUnitType mHidType;  ///< hidden unit type
  };




  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuRbm::
  /// Creates an RBM with nInputs visible units and nOutputs hidden units.
  ///
  /// The weight matrix and the two bias vectors are only sized here. Their
  /// contents are whatever CuMatrix/CuVector construction leaves, presumably
  /// filled later (e.g. by ReadFromStream) — TODO confirm. The three update
  /// buffers are explicitly set to zero. Both unit types start as BERNOULLI.
  inline
  CuRbm::CuRbm(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuRbmBase(nInputs, nOutputs, pPred),
      mVisHid(nInputs, nOutputs),
      mVisBias(nInputs),
      mHidBias(nOutputs),
      mVisHidCorrection(nInputs, nOutputs),
      mVisBiasCorrection(nInputs),
      mHidBiasCorrection(nOutputs),
      mBackpropErrBuf(),
      mVisType(BERNOULLI),
      mHidType(BERNOULLI)
  {
    // Update accumulators must start from zero; the three clears are
    // independent of one another.
    mHidBiasCorrection.SetConst(0.0);
    mVisBiasCorrection.SetConst(0.0);
    mVisHidCorrection.SetConst(0.0);
  }


  /// Destructor. Nothing is released explicitly; every member is destroyed
  /// automatically.
  inline CuRbm::~CuRbm()
  {
  }

  /// Reports this component's type tag, CuComponent::RBM.
  inline CuComponent::ComponentType
  CuRbm::GetType() const
  {
    const CuComponent::ComponentType kind = CuComponent::RBM;
    return kind;
  }

  /// Returns the component's name tag, "<rbm>" (a string literal with static
  /// storage duration, so the pointer stays valid).
  inline const char*
  CuRbm::GetName() const
  {
    const char* const name = "<rbm>";
    return name;
  }



} //namespace



#endif