#include #include #include "cuRbm.h" #include "cumath.h" namespace TNet { void CuRbm:: PropagateFnc(const CuMatrix& X, CuMatrix& Y) { Y.SetConst(0.0); Y.AddScaledRow(1.0,mHidBias,0.0); Y.Gemm('N','N', 1.0, X, mVisHid, 1.0); if(mHidType == BERNOULLI) { CuMath::Sigmoid(Y,Y); } } void CuRbm:: BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) { if(mHidType == BERNOULLI) { mBackpropErrBuf.Init(X.Rows(),X.Cols()); CuMath::DiffSigmoid(mBackpropErrBuf,X,GetOutput()); } else { mBackpropErrBuf.CopyFrom(X); } Y.SetConst(0.0); Y.Gemm('N', 'T', 1.0, mBackpropErrBuf, mVisHid, 0.0); } void CuRbm:: Update() { //THIS IS DONE TWICE BECAUSE OF THE BACKPROP STOPPER!!! if(mHidType == BERNOULLI) { mBackpropErrBuf.Init(GetErrorInput().Rows(),GetErrorInput().Cols()); CuMath::DiffSigmoid(mBackpropErrBuf,GetErrorInput(),GetOutput()); } else { mBackpropErrBuf.CopyFrom(GetErrorInput()); } /* std::cout << " " << GetInput().Rows() << " " << GetInput().Cols() << " " << mBackpropErrBuf.Rows() << " " << mBackpropErrBuf.Cols() << " " << mVisHidCorrection.Rows() << " " << mVisHidCorrection.Cols() ; */ #if 0 //former implementation BaseFloat N = static_cast(GetInput().Rows()); mVisHidCorrection.Gemm('T','N',-mLearningRate/N,GetInput(),mBackpropErrBuf,mMomentum); mHidBiasCorrection.AddColSum(-mLearningRate/N,mBackpropErrBuf,mMomentum); //regularization weight decay mVisHidCorrection.AddScaled(-mLearningRate*mWeightcost,mVisHid,1.0); mVisHid.AddScaled(1.0,mVisHidCorrection,1.0); mHidBias.AddScaled(1.0,mHidBiasCorrection,1.0); #endif #if 1 //new implementation BaseFloat N = 1; if(mGradDivFrm) { N = static_cast(GetInput().Rows()); } BaseFloat mmt_gain = static_cast(1.0/(1.0-mMomentum)); N *= mmt_gain; mVisHidCorrection.Gemm('T','N',1.0,GetInput(),mBackpropErrBuf,mMomentum); mHidBiasCorrection.AddColSum(1.0,mBackpropErrBuf,mMomentum); mVisHid.AddScaled(-mLearningRate/N,mVisHidCorrection,1.0); mHidBias.AddScaled(-mLearningRate/N,mHidBiasCorrection,1.0); //regularization weight decay (from actual weights only) mVisHid.AddScaled(-mLearningRate*mWeightcost,mVisHid,1.0); #endif } void CuRbm:: Propagate(const CuMatrix& visProbs, CuMatrix& hidProbs) { if(visProbs.Cols() != GetNInputs()) { std::ostringstream os; os << " Nonmatching input dim, needs:" << GetNInputs() << " got:" << visProbs.Cols() << "\n"; Error(os.str()); } hidProbs.Init(visProbs.Rows(),GetNOutputs()); PropagateFnc(visProbs, hidProbs); } void CuRbm:: Reconstruct(const CuMatrix& hidState, CuMatrix& visProbs) { visProbs.Init(hidState.Rows(),mNInputs); visProbs.SetConst(0.0); visProbs.AddScaledRow(1.0,mVisBias,0.0); visProbs.Gemm('N','T', 1.0, hidState, mVisHid, 1.0); if(mVisType == BERNOULLI) { CuMath::Sigmoid(visProbs,visProbs); } } void CuRbm:: RbmUpdate(const CuMatrix& pos_vis, const CuMatrix& pos_hid, const CuMatrix& neg_vis, const CuMatrix& neg_hid) { assert(pos_vis.Rows() == pos_hid.Rows() && pos_vis.Rows() == neg_vis.Rows() && pos_vis.Rows() == neg_hid.Rows() && pos_vis.Cols() == neg_vis.Cols() && pos_hid.Cols() == neg_hid.Cols() && pos_vis.Cols() == mNInputs && pos_hid.Cols() == mNOutputs); // UPDATE vishid matrix // // vishidinc = momentum*vishidinc + ... // epsilonw*( (posprods-negprods)/numcases - weightcost*vishid); // // vishidinc[t] = -(epsilonw/numcases)*negprods + momentum*vishidinc[t-1] // +(epsilonw/numcases)*posprods // -(epsilonw*weightcost)*vishid[t-1] // BaseFloat N = static_cast(pos_vis.Rows()); mVisHidCorrection.Gemm('T','N',-mLearningRate/N,neg_vis,neg_hid,mMomentum); mVisHidCorrection.Gemm('T','N',+mLearningRate/N,pos_vis,pos_hid,1.0); mVisHidCorrection.AddScaled(-mLearningRate*mWeightcost,mVisHid,1.0); mVisHid.AddScaled(1.0,mVisHidCorrection,1.0); // UPDATE visbias vector // // visbiasinc = momentum*visbiasinc + (epsilonvb/numcases)*(posvisact-negvisact); // mVisBiasCorrection.AddColSum(-mLearningRate/N,neg_vis,mMomentum); mVisBiasCorrection.AddColSum(+mLearningRate/N,pos_vis,1.0); mVisBias.AddScaled(1.0,mVisBiasCorrection,1.0); // UPDATE hidbias vector // // hidbiasinc = momentum*hidbiasinc + (epsilonhb/numcases)*(poshidact-neghidact); // mHidBiasCorrection.AddColSum(-mLearningRate/N,neg_hid,mMomentum); mHidBiasCorrection.AddColSum(+mLearningRate/N,pos_hid,1.0); mHidBias.AddScaled(1.0/*0.0*/,mHidBiasCorrection,1.0); } void CuRbm:: ReadFromStream(std::istream& rIn) { //type of the units std::string str; rIn >> std::ws >> str; if(0 == str.compare("bern")) { mVisType = BERNOULLI; } else if(0 == str.compare("gauss")) { mVisType = GAUSSIAN; } else Error(std::string("Invalid unit type: ")+str); rIn >> std::ws >> str; if(0 == str.compare("bern")) { mHidType = BERNOULLI; } else if(0 == str.compare("gauss")) { mHidType = GAUSSIAN; } else Error(std::string("Invalid unit type: ")+str); //matrix is stored transposed as SNet does BfMatrix transpose; rIn >> transpose; mVisHid.CopyFrom(BfMatrix(transpose, TRANS)); //biases stored normally BfVector bias; rIn >> bias; mVisBias.CopyFrom(bias); rIn >> bias; mHidBias.CopyFrom(bias); } void CuRbm:: WriteToStream(std::ostream& rOut) { //store unit type info if(mVisType == BERNOULLI) { rOut << " bern "; } else { rOut << " gauss "; } if(mHidType == BERNOULLI) { rOut << " bern\n"; } else { rOut << " gauss\n"; } //matrix is stored transposed as SNet does BfMatrix tmp; mVisHid.CopyTo(tmp); BfMatrix transpose(tmp, TRANS); rOut << transpose; //biases stored normally BfVector vec; mVisBias.CopyTo(vec); rOut << vec; rOut << std::endl; mHidBias.CopyTo(vec); rOut << vec; rOut << std::endl; } } //namespace