Diffstat (limited to 'src/CuTNetLib/cuSparseLinearity.cc')
-rw-r--r-- | src/CuTNetLib/cuSparseLinearity.cc | 190
1 file changed, 190 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuSparseLinearity.cc b/src/CuTNetLib/cuSparseLinearity.cc
new file mode 100644
index 0000000..7209630
--- /dev/null
+++ b/src/CuTNetLib/cuSparseLinearity.cc
@@ -0,0 +1,190 @@
+
+
+#include "cuSparseLinearity.h"
+#include <cmath>
+#include <cstdlib>
+
+
+namespace TNet
+{
+
+  void
+  CuSparseLinearity::
+  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+  {
+    //Y = bias (broadcast to all rows), then Y += X * W
+    Y.AddScaledRow(1.0,mBias,0.0);
+    Y.Gemm('N','N', 1.0, X, mLinearity, 1.0);
+  }
+
+
+  void
+  CuSparseLinearity::
+  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+  {
+    Y.Gemm('N', 'T', 1.0, X, mLinearity, 0.0);
+  }
+
+
+  void
+  CuSparseLinearity::
+  Update()
+  {
+    BaseFloat N = 1;
+    if(mGradDivFrm) {
+      N = static_cast<BaseFloat>(GetInput().Rows());
+    }
+    BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
+    N *= mmt_gain;
+
+    mLinearityCorrection.Gemm('T','N',1.0,GetInput(),GetErrorInput(),mMomentum);
+    mBiasCorrection.AddColSum(1.0,GetErrorInput(),mMomentum);
+
+    mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection,1.0);
+    mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);
+
+    mLinearityCorrectionAccu.AddScaled(1.0,mLinearityCorrection,1.0);
+    mLinearity.ApplyMask(mSparsityMask);
+
+    //L1 regularization lasso...
+    //each update? every 1000th update?
+    if(mL1Const > 0) {
+      BaseFloat L1_const = mLearningRate*mL1Const*(mGradDivFrm?1.0:GetInput().Rows());
+      mLinearity.ApplyL1(L1_const);
+    }
+
+    //L2 regularization weight decay (from actual weights only)
+    if(mWeightcost > 0) {
+      BaseFloat L2_decay = -mLearningRate*mWeightcost*(mGradDivFrm?1.0:GetInput().Rows());
+      mLinearity.AddScaled(L2_decay, mLinearity,1.0);
+    }
+
+    mNFrames += GetInput().Rows();
+
+  }
+
+
+  void
+  CuSparseLinearity::
+  UpdateMask()
+  {
+    //move data to host
+    Matrix<BaseFloat> linearity, linearity_correction_accu;
+    Matrix<BaseFloat> sparsity_mask;
+
+    mLinearity.CopyTo(linearity);
+    mLinearityCorrectionAccu.CopyTo(linearity_correction_accu);
+    mSparsityMask.CopyTo(sparsity_mask);
+
+    //decide on new sparsity mask
+    for(size_t r=0; r<sparsity_mask.Rows(); r++) {
+      for(size_t c=0; c<sparsity_mask.Cols(); c++) {
+        if(sparsity_mask(r,c) == 1.0f) { //weight active
+          if(fabs(linearity(r,c)) < mSparsifyWeightThreshold) {
+            sparsity_mask(r,c) = 0;//deactivate
+            linearity(r,c) = 0;
+          }
+        } else { //weight inactive
+          if(fabs(linearity_correction_accu(r,c))/(BaseFloat)mNFrames > mUnsparsifyAccu) {
+            sparsity_mask(r,c) = 1;//activate
+          }
+        }
+      }
+    }
+
+    //move data to the device
+    mLinearity.CopyFrom(linearity);
+    mSparsityMask.CopyFrom(sparsity_mask);
+  }
+
+
+  void
+  CuSparseLinearity::
+  ReadFromStream(std::istream& rIn)
+  {
+    //matrix is stored transposed as SNet does
+    BfMatrix transpose;
+    rIn >> transpose;
+    mLinearity.CopyFrom(BfMatrix(transpose, TRANS));
+    //biases stored normally
+    BfVector bias;
+    rIn >> bias;
+    mBias.CopyFrom(bias);
+
+    //sparsity mask
+    rIn >> std::ws;
+    Matrix<BaseFloat> mask_transp;
+    if(rIn.peek() == 'm') {//load from file
+      rIn >> mask_transp;
+    } else {//or set all elements active
+      mask_transp.Init(transpose.Rows(),transpose.Cols());
+      int items=transpose.Rows()*transpose.Stride();
+      BaseFloat* p = mask_transp.pData();
+      for(int i=0; i<items; i++) {//set all elements to one
+        *p++ = 1;
+      }
+    }
+    mSparsityMask.CopyFrom(BfMatrix(mask_transp,TRANS));
+
+    //dummy matrix with accumulated gradients
+    rIn >> std::ws;
+    if(rIn.peek() == 'm') {//load from file
+      BfMatrix dummy;
+      rIn >> dummy;
+    }
+
+    if(transpose.Cols()*transpose.Rows() == 0) {
+      Error("Missing linearity matrix in network file");
+    }
+    if(bias.Dim() == 0) {
+      Error("Missing bias vector in network file");
+    }
+    if(mLinearity.Cols() != GetNOutputs() ||
+       mLinearity.Rows() != GetNInputs() ||
+       mBias.Dim() != GetNOutputs()
+    ){
+      std::ostringstream os;
+      os << "Wrong dimensionalities of matrix/vector in network file\n"
+         << "Inputs:" << GetNInputs()
+         << " Outputs:" << GetNOutputs()
+         << "\n"
+         << "linearityCols:" << mLinearity.Cols()
+         << " linearityRows:" << mLinearity.Rows()
+         << " biasDims:" << mBias.Dim()
+         << "\n";
+      Error(os.str());
+    }
+
+    assert(mLinearity.Rows() == mSparsityMask.Rows());
+    assert(mLinearity.Cols() == mSparsityMask.Cols());
+
+  }
+
+
+  void
+  CuSparseLinearity::
+  WriteToStream(std::ostream& rOut)
+  {
+    UpdateMask();
+
+    //matrix is stored transposed as SNet does
+    BfMatrix tmp;
+    mLinearity.CopyTo(tmp);
+    BfMatrix transpose(tmp, TRANS);
+    rOut << transpose;
+    //biases stored normally
+    BfVector vec;
+    mBias.CopyTo(vec);
+    rOut << vec;
+    rOut << std::endl;
+    //store mask
+    mSparsityMask.CopyTo(tmp);
+    rOut << BfMatrix(tmp,TRANS);
+    //store accu
+    mLinearityCorrectionAccu.CopyTo(tmp);
+    rOut << BfMatrix(tmp,TRANS);
+
+  }
+
+
+} //namespace
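
Note on the L1 step: Update() delegates the lasso shrinkage to mLinearity.ApplyL1(L1_const), and the CUDA kernel behind that call is not part of this diff. As a rough CPU illustration only (ApplyL1Host is a hypothetical helper, not TNet API), the per-weight operation one would expect is a soft-threshold that moves each weight toward zero by L1_const and clamps it at zero rather than letting it cross the sign boundary:

#include <algorithm>
#include <cstdio>

// Hypothetical CPU sketch of an L1 proximal (soft-threshold) step:
// shrink each weight toward zero by c, clamping at zero so the
// update cannot flip the sign of a weight.
static void ApplyL1Host(float* w, int n, float c) {
  for(int i = 0; i < n; i++) {
    if(w[i] > 0) w[i] = std::max(0.0f, w[i] - c);
    else         w[i] = std::min(0.0f, w[i] + c);
  }
}

int main() {
  float w[4] = { 0.30f, -0.05f, 0.002f, -0.80f };
  ApplyL1Host(w, 4, 0.01f);                // c plays the role of L1_const
  for(float v : w) std::printf("%g ", v);  // prints: 0.29 -0.04 0 -0.79
  std::printf("\n");
  return 0;
}

This is also why Update() scales L1_const by GetInput().Rows() when mGradDivFrm is false: the penalty stays proportional to the effective gradient magnitude of the batch. Small weights hit the zero clamp, which is exactly what feeds the sparsity mask below.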
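The mask update itself is easy to miss inside the host-copy plumbing of UpdateMask(): an active weight whose magnitude drops below mSparsifyWeightThreshold is zeroed and deactivated, while an inactive weight whose accumulated correction per frame exceeds mUnsparsifyAccu is reactivated. A standalone sketch of just that decision rule, on flat arrays instead of TNet's Matrix<BaseFloat> (names here are illustrative, not TNet API):

#include <cmath>
#include <cstddef>

// Per-element sparsity-mask update, mirroring the loop in
// CuSparseLinearity::UpdateMask(); sparsify_thr and unsparsify_accu
// correspond to mSparsifyWeightThreshold and mUnsparsifyAccu.
void UpdateMaskHost(float* w, const float* grad_accu, float* mask,
                    size_t n, size_t n_frames,
                    float sparsify_thr, float unsparsify_accu) {
  for(size_t i = 0; i < n; i++) {
    if(mask[i] == 1.0f) {                  // weight active
      if(std::fabs(w[i]) < sparsify_thr) {
        mask[i] = 0.0f;                    // prune: deactivate...
        w[i] = 0.0f;                       // ...and zero the weight
      }
    } else {                               // weight inactive
      if(std::fabs(grad_accu[i]) / n_frames > unsparsify_accu) {
        mask[i] = 1.0f;                    // reactivate under gradient pressure
      }
    }
  }
}

The rule is deliberately asymmetric: pruning looks at the weight's current magnitude, while revival looks at the gradient accumulated per frame (mLinearityCorrectionAccu / mNFrames), so a pruned connection can return if the data keeps pushing on it.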