From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001
From: Joe Zhao
Date: Mon, 14 Apr 2014 08:14:45 +0800
Subject: First commit

---
 src/CuTNetLib/cuLinearity.cc | 107 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 src/CuTNetLib/cuLinearity.cc

diff --git a/src/CuTNetLib/cuLinearity.cc b/src/CuTNetLib/cuLinearity.cc
new file mode 100644
index 0000000..5fb247d
--- /dev/null
+++ b/src/CuTNetLib/cuLinearity.cc
@@ -0,0 +1,107 @@
+
+
+#include "cuLinearity.h"
+
+
+namespace TNet
+{
+
+  void
+  CuLinearity::
+  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+  {
+    //Y.SetConst(0.0);
+    Y.Gemm('N','N', 1.0, X, mLinearity, 0.0);
+  }
+
+
+  void
+  CuLinearity::
+  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+  {
+    //Y.SetConst(0.0);
+    Y.Gemm('N', 'T', 1.0, X, mLinearity, 0.0);
+  }
+
+
+  void
+  CuLinearity::
+  Update()
+  {
+#if 0
+    //former implementation
+    BaseFloat N = static_cast<BaseFloat>(GetInput().Rows());
+
+    mLinearityCorrection.Gemm('T','N',-mLearningRate/N,GetInput(),GetErrorInput(),mMomentum);
+    mBiasCorrection.AddColSum(-mLearningRate/N,GetErrorInput(),mMomentum);
+
+    //regularization weight decay
+    mLinearityCorrection.AddScaled(-mLearningRate*mWeightcost,mLinearity,1.0);
+
+    mLinearity.AddScaled(1.0,mLinearityCorrection,1.0);
+    mBias.AddScaled(1.0,mBiasCorrection,1.0);
+#endif
+
+#if 1
+    //new implementation
+    BaseFloat N = 1;
+    if(mGradDivFrm) {
+      N = static_cast<BaseFloat>(GetInput().Rows());
+    }
+    BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
+    N *= mmt_gain;
+
+    mLinearityCorrection.Gemm('T','N',1.0,GetInput(),GetErrorInput(),mMomentum);
+
+    mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection,1.0);
+
+    //regularization weight decay (from actual weights only)
+    BaseFloat L2_decay = -mLearningRate*mWeightcost*(mGradDivFrm?1.0:GetInput().Rows());
+    mLinearity.AddScaled(L2_decay, mLinearity,1.0);
+#endif
+  }
+
+
+  void
+  CuLinearity::
+  ReadFromStream(std::istream& rIn)
+  {
+    //matrix is stored transposed as SNet does
+    BfMatrix transpose;
+    rIn >> transpose;
+    mLinearity.CopyFrom(BfMatrix(transpose, TRANS));
+
+    /*if(transpose.Cols()*transpose.Rows() == 0) {
+      Error("Missing linearity matrix in network file");
+    }*/
+    if(mLinearity.Cols() != GetNOutputs() ||
+       mLinearity.Rows() != GetNInputs()
+    ){
+      std::ostringstream os;
+      os << "Wrong dimensionalities of matrix/vector in network file\n"
+         << "Inputs:" << GetNInputs()
+         << "Outputs:" << GetNOutputs()
+         << "\n"
+         << "linearityCols:" << mLinearity.Cols()
+         << "linearityRows:" << mLinearity.Rows()
+         << "\n";
+      Error(os.str());
+    }
+  }
+
+
+  void
+  CuLinearity::
+  WriteToStream(std::ostream& rOut)
+  {
+    //matrix is stored transposed as SNet does
+    BfMatrix tmp;
+    mLinearity.CopyTo(tmp);
+    BfMatrix transpose(tmp, TRANS);
+    rOut << transpose;
+    rOut << std::endl;
+  }
+
+
+} //namespace
+
--
cgit v1.2.3-70-g09d2
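
The "#if 1" branch of CuLinearity::Update() above folds two scalings into a single divisor N: the optional per-frame averaging (mGradDivFrm) and the asymptotic momentum gain 1/(1 - mMomentum), which keeps the effective step size roughly independent of the momentum setting. The L2 weight decay is then taken from the actual weights rather than from the correction buffer. Below is a minimal CPU-side sketch of the same arithmetic on plain std::vector<float> buffers; it is not part of the patch, and the function name UpdateLinearity as well as all parameter names are hypothetical stand-ins for the CuMatrix members used in the file.

#include <cstddef>
#include <vector>

// W    : weights, row-major [num_in x num_out]
// corr : persistent momentum/correction buffer, same shape as W
// X    : forward input to the layer, row-major [num_frames x num_in]
// E    : back-propagated error,      row-major [num_frames x num_out]
void UpdateLinearity(std::vector<float>& W, std::vector<float>& corr,
                     const std::vector<float>& X, const std::vector<float>& E,
                     std::size_t num_frames, std::size_t num_in, std::size_t num_out,
                     float learn_rate, float momentum, float weightcost,
                     bool grad_div_frm)
{
  // N absorbs the per-frame averaging (if enabled) and the asymptotic
  // momentum gain 1/(1-momentum), as in the "#if 1" branch of Update().
  float N = grad_div_frm ? static_cast<float>(num_frames) : 1.0f;
  N *= 1.0f / (1.0f - momentum);

  // corr = X^T * E + momentum * corr   -- the Gemm('T','N',...) call
  for (std::size_t i = 0; i < num_in; ++i) {
    for (std::size_t o = 0; o < num_out; ++o) {
      float g = 0.0f;
      for (std::size_t f = 0; f < num_frames; ++f)
        g += X[f * num_in + i] * E[f * num_out + o];
      corr[i * num_out + o] = g + momentum * corr[i * num_out + o];
    }
  }

  // L2 decay coefficient taken from the actual weights; it is scaled back
  // up by the frame count when the gradient itself is not divided by it.
  float l2_decay = -learn_rate * weightcost *
                   (grad_div_frm ? 1.0f : static_cast<float>(num_frames));

  for (std::size_t k = 0; k < W.size(); ++k) {
    W[k] += -(learn_rate / N) * corr[k];  // W += -(lr/N) * correction
    W[k] += l2_decay * W[k];              // W += L2_decay * W
  }
}

As in the patch, the decay term is applied after the gradient step, so it acts on the already-updated weights.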