// Declaration of TNet::CuRbmSparse, an RBM component derived from CuRbmBase
// that carries additional sparsity-related state.
#ifndef _CU_RBM_SPARSE_H_
#define _CU_RBM_SPARSE_H_
#include "cuComponent.h"
#include "cumatrix.h"
#include "cuRbm.h"
#include "Matrix.h"
#include "Vector.h"
namespace TNet {
/**
 * RBM component derived from CuRbmBase that carries extra sparsity state
 * (target prior, decay factor, cost coefficient and per-unit buffers).
 *
 * Only the constructor, the destructor, GetType() and GetName() are defined
 * in this header; every other member function is merely declared here.
 */
class CuRbmSparse : public CuRbmBase
{
 public:
  /// Sizes the model for nInputs visible units and nOutputs hidden units;
  /// all three arguments are forwarded to CuRbmBase.
  CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred);
  ~CuRbmSparse();

  /// Component type tag of this class (CuComponent::RBM_SPARSE).
  ComponentType GetType() const;
  /// Textual tag of this class ("<rbmsparse>").
  const char* GetName() const;

  // ---- CuUpdatableComponent API ------------------------------------------
  void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
  void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
  void Update();

  // ---- RBM training API --------------------------------------------------
  /// Presumably maps visible probabilities to hidden probabilities;
  /// the definition is not part of this header.
  void Propagate(const CuMatrix<BaseFloat>& visProbs,
                 CuMatrix<BaseFloat>& hidProbs);
  /// Presumably maps a hidden state back to visible probabilities;
  /// the definition is not part of this header.
  void Reconstruct(const CuMatrix<BaseFloat>& hidState,
                   CuMatrix<BaseFloat>& visProbs);
  /// Takes positive-phase and negative-phase visible/hidden matrices;
  /// the update rule itself is defined outside this header.
  void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis,
                 const CuMatrix<BaseFloat>& pos_hid,
                 const CuMatrix<BaseFloat>& neg_vis,
                 const CuMatrix<BaseFloat>& neg_hid);

  /// Unit type currently stored for the visible layer.
  RbmUnitType VisType() { return mVisType; }
  /// Unit type currently stored for the hidden layer.
  RbmUnitType HidType() { return mHidType; }

  // Disabled declaration kept for reference:
  // static void BinarizeProbs(const CuMatrix<BaseFloat>& probs, CuMatrix<BaseFloat>& states);

  // ---- I/O ---------------------------------------------------------------
  void ReadFromStream(std::istream& rIn);
  void WriteToStream(std::ostream& rOut);

 protected:
  // NOTE: members are initialized in the order they are declared below,
  // so keep this order in sync with the constructor's initializer list.

  CuMatrix<BaseFloat> mVisHid;             ///< Neuron weights, constructed as (nInputs, nOutputs)
  CuVector<BaseFloat> mVisBias;            ///< Visible-layer biases, length nInputs
  CuVector<BaseFloat> mHidBias;            ///< Hidden-layer biases, length nOutputs

  CuMatrix<BaseFloat> mVisHidCorrection;   ///< Update buffer for mVisHid, same dimensions
  CuVector<BaseFloat> mVisBiasCorrection;  ///< Update buffer for mVisBias
  CuVector<BaseFloat> mHidBiasCorrection;  ///< Update buffer for mHidBias

  CuMatrix<BaseFloat> mBackpropErrBuf;     ///< Default-constructed matrix buffer

  RbmUnitType mVisType;                    ///< Visible unit type, returned by VisType()
  RbmUnitType mHidType;                    ///< Hidden unit type, returned by HidType()

  // ---- sparsity ----------------------------------------------------------
  BaseFloat mSparsityPrior;                ///< Sparsity target (unit activity prior)
  BaseFloat mLambda;                       ///< Exponential decay factor for q (observed probability of a unit being active)
  BaseFloat mSparsityCost;                 ///< Sparsity cost coefficient

  CuVector<BaseFloat> mSparsityQ;          ///< Length nOutputs; starts at mSparsityPrior
  CuVector<BaseFloat> mSparsityQCurrent;   ///< Length nOutputs; starts at zero

  CuVector<BaseFloat> mVisMean;            ///< Buffer for the mean visible vector, length nInputs
};
////////////////////////////////////////////////////////////////////////////
// INLINE FUNCTIONS
// CuRbmSparse::
/**
 * Constructs the component for nInputs visible and nOutputs hidden units.
 *
 * The weight matrix and both bias vectors are only sized here; no SetConst
 * call is made on them in this constructor. The three correction buffers,
 * the current-Q buffer and the visible-mean buffer are set to zero, and the
 * Q vector is filled with the sparsity prior.
 *
 * @param nInputs   number of visible units (forwarded to CuRbmBase)
 * @param nOutputs  number of hidden units (forwarded to CuRbmBase)
 * @param pPred     forwarded unchanged to CuRbmBase
 */
inline CuRbmSparse::CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred)
  : CuRbmBase(nInputs, nOutputs, pPred),
    // model parameters
    mVisHid(nInputs, nOutputs),
    mVisBias(nInputs),
    mHidBias(nOutputs),
    // update buffers, shaped like the parameters they belong to
    mVisHidCorrection(nInputs, nOutputs),
    mVisBiasCorrection(nInputs),
    mHidBiasCorrection(nOutputs),
    mBackpropErrBuf(),
    // both layers start out as Bernoulli units
    mVisType(BERNOULLI),
    mHidType(BERNOULLI),
    // sparsity hyper-parameters
    mSparsityPrior(0.0001),
    mLambda(0.95),
    mSparsityCost(1e-7),
    // sparsity buffers (one entry per hidden unit) and visible mean
    mSparsityQ(nOutputs),
    mSparsityQCurrent(nOutputs),
    mVisMean(nInputs)
{
  // Start with empty update buffers.
  mVisHidCorrection.SetConst(0.0);
  mVisBiasCorrection.SetConst(0.0);
  mHidBiasCorrection.SetConst(0.0);

  // Q begins at the target prior; the current-Q buffer begins at zero.
  mSparsityQ.SetConst(mSparsityPrior);
  mSparsityQCurrent.SetConst(0.0);

  mVisMean.SetConst(0.0);
}
/// Destructor with an empty body; the members release themselves through
/// their own destructors.
inline CuRbmSparse::~CuRbmSparse()
{
}
/// @return the component type tag of this class, CuComponent::RBM_SPARSE.
inline CuComponent::ComponentType CuRbmSparse::GetType() const
{
  return CuComponent::RBM_SPARSE;
}
/// @return the textual tag of this class, the literal "<rbmsparse>".
inline const char* CuRbmSparse::GetName() const
{
  return "<rbmsparse>";
}
} //namespace
#endif
// end of include guard _CU_RBM_SPARSE_H_