#include "cuSharedLinearity.h"
#include "cumath.h"


namespace TNet {

/**
 * Forward pass of a layer whose weight matrix and bias are shared by
 * mNInstances consecutive slices of the input/output.
 *
 * Y = [X_0*W, X_1*W, ..., X_{n-1}*W] + [b, b, ..., b]
 *
 * @param X input minibatch (frames x mNInstances*mLinearity.Rows())
 * @param Y output minibatch (frames x mNInstances*mLinearity.Cols())
 */
void
CuSharedLinearity::
PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
{
  //replicate the shared bias over all the instances
  CuMath<BaseFloat>::VecExpand(mBias,mBiasExpand); /// [ 1 2 3 ] -> [ 1 2 3 1 2 3 ... ]
  //initialize Y with the expanded bias (beta=0.0 overwrites Y)
  Y.AddScaledRow(1.0,mBiasExpand,0.0);
  //mBiasExpand.Print();

  //one GEMM per instance, each reading/writing its own column offset
  for(int i=0; i<mNInstances; i++) {
    CuMath<BaseFloat>::OffsetGemm('N','N', 1.0, X, mLinearity, 1.0, Y,
                                  i*mLinearity.Rows(), 0, i*mLinearity.Cols());
  }
  //std::cout << CuDevice::Instantiate().GetFreeMemory();
  //GetInput().Print();
  //GetOutput().Print();
}


/**
 * Backward pass: propagate the error through each instance of the
 * shared linearity (multiplies by W^T per instance).
 *
 * @param X error coming from above (frames x mNInstances*mLinearity.Cols())
 * @param Y error to pass below (frames x mNInstances*mLinearity.Rows())
 */
void
CuSharedLinearity::
BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
{
  for(int i=0; i<mNInstances; i++) {
    CuMath<BaseFloat>::OffsetGemm('N', 'T', 1.0, X, mLinearity, 0.0, Y,
                                  i*mLinearity.Cols(), 0, i*mLinearity.Rows());
  }
}


/**
 * SGD update of the shared weights/bias with momentum and weight decay.
 * Gradients of all instances are accumulated into the single shared
 * mLinearityCorrection/mBiasCorrection before the update is applied.
 */
void
CuSharedLinearity::
Update()
{
#if 0
  //former implementation
  BaseFloat N = static_cast<BaseFloat>(GetInput().Rows());

  //accumulate weight gradient over all instances
  //(i==0 scales the accumulator by momentum, later instances add with 1.0)
  for(int i=0; i<mNInstances; i++) {
    CuMath<BaseFloat>::OffsetGemm('T','N',-mLearningRate/(N*mNInstances),
                                  GetInput(),GetErrorInput(),
                                  ((i==0)?mMomentum:1.0f), mLinearityCorrection,
                                  i*mLinearity.Rows(),i*mLinearity.Cols(),0);
  }
  mBiasCorrectionExpand.AddColSum(1.0,GetErrorInput(),0.0);
  CuMath<BaseFloat>::VecAddColSum(-mLearningRate/(N*mNInstances),mBiasCorrectionExpand,
                                  mMomentum,mBiasCorrection);

  //regularization weight decay
  mLinearityCorrection.AddScaled(-mLearningRate*mWeightcost,mLinearity,1.0);

  mLinearity.AddScaled(1.0,mLinearityCorrection,1.0);
  mBias.AddScaled(1.0,mBiasCorrection,1.0);
#endif

#if 1
  //new implementation
  BaseFloat N = 1;
  if(mGradDivFrm) {
    N = static_cast<BaseFloat>(GetInput().Rows());
  }
  BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
  N *= mmt_gain; //compensate higher gradient estimates due to momentum

  //compensate augmented dyn. range of gradient caused by multiple instances
  N *= static_cast<BaseFloat>(mNInstances);

  //get gradient of shared linearity
  //(first instance scales the accumulator by momentum, the rest add in)
  for(int i=0; i<mNInstances; i++) {
    CuMath<BaseFloat>::OffsetGemm('T','N',1.0,
                                  GetInput(),GetErrorInput(),
                                  ((i==0)?mMomentum:1.0f), mLinearityCorrection,
                                  i*mLinearity.Rows(),i*mLinearity.Cols(),0);
  }
  //get gradient of shared bias
  mBiasCorrectionExpand.AddColSum(1.0,GetErrorInput(),0.0);
  CuMath<BaseFloat>::VecAddColSum(1.0,mBiasCorrectionExpand,mMomentum,mBiasCorrection);

  //perform update
  mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection,1.0);
  mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);

  //regularization weight decay
  mLinearity.AddScaled(-mLearningRate*mWeightcost,mLinearity,1.0);
#endif
}


/**
 * Read the layer from a stream: number of instances, then the
 * (transposed, SNet-format) weight matrix and the bias vector.
 * Validates divisibility of layer dims by mNInstances and the
 * matrix/vector dimensions; allocates the correction buffers.
 *
 * @param rIn stream positioned at this layer's data
 */
void
CuSharedLinearity::
ReadFromStream(std::istream& rIn)
{
  //number of instances of shared weights in layer
  rIn >> std::ws >> mNInstances;
  if(mNInstances < 1) {
    std::ostringstream os;
    os << "Bad number of instances:" << mNInstances;
    Error(os.str());
  }
  if(GetNInputs() % mNInstances != 0 || GetNOutputs() % mNInstances != 0) {
    std::ostringstream os;
    os << "Number of Inputs/Outputs must be divisible by number of instances"
       << " Inputs:" << GetNInputs()
       << " Outputs" << GetNOutputs()
       << " Intances:" << mNInstances;
    Error(os.str());
  }

  //matrix is stored transposed as SNet does
  BfMatrix transpose;
  rIn >> transpose;
  mLinearity.CopyFrom(BfMatrix(transpose, TRANS));
  //biases stored normally
  BfVector bias;
  rIn >> bias;
  mBias.CopyFrom(bias);

  if(transpose.Cols()*transpose.Rows() == 0) {
    Error("Missing linearity matrix in network file");
  }
  if(bias.Dim() == 0) {
    Error("Missing bias vector in network file");
  }

  //dimensions are per-instance, hence the division by mNInstances
  if(mLinearity.Cols() != GetNOutputs() / mNInstances ||
     mLinearity.Rows() != GetNInputs() / mNInstances ||
     mBias.Dim() != GetNOutputs() / mNInstances
  ){
    std::ostringstream os;
    os << "Wrong dimensionalities of matrix/vector in network file\n"
       << "Inputs:" << GetNInputs()
       << "Outputs:" << GetNOutputs()
       << "\n"
       << "linearityCols:" << mLinearity.Cols()
       << "linearityRows:" << mLinearity.Rows()
       << "biasDims:" << mBias.Dim()
       << "\n";
    Error(os.str());
  }

  //allocate momentum accumulators and the per-instance expanded bias buffers
  mLinearityCorrection.Init(mLinearity.Rows(),mLinearity.Cols());
  mBiasCorrection.Init(mBias.Dim());
  mBiasExpand.Init(mBias.Dim()*mNInstances);
  mBiasCorrectionExpand.Init(mBias.Dim()*mNInstances);
}


/**
 * Write the layer to a stream in the same format ReadFromStream expects:
 * number of instances, transposed weight matrix (SNet format), bias.
 *
 * @param rOut output stream
 */
void
CuSharedLinearity::
WriteToStream(std::ostream& rOut)
{
  rOut << mNInstances << std::endl;
  //matrix is stored transposed as SNet does
  BfMatrix tmp;
  mLinearity.CopyTo(tmp);
  BfMatrix transpose(tmp, TRANS);
  rOut << transpose;
  //biases stored normally
  BfVector vec;
  mBias.CopyTo(vec);
  rOut << vec;
  rOut << std::endl;
}

} //namespace