author:    Joe Zhao <ztuowen@gmail.com>  2014-04-14 08:14:45 +0800
committer: Joe Zhao <ztuowen@gmail.com>  2014-04-14 08:14:45 +0800
commit:    cccccbf6cca94a3eaf813b4468453160e91c332b
tree:      23418cb73a10ae3b0688681a7f0ba9b06424583e  /src/TNetLib/SharedLinearity.cc
First commit
Diffstat (limited to 'src/TNetLib/SharedLinearity.cc')
-rw-r--r--   src/TNetLib/SharedLinearity.cc   277
1 file changed, 277 insertions, 0 deletions
diff --git a/src/TNetLib/SharedLinearity.cc b/src/TNetLib/SharedLinearity.cc
new file mode 100644
index 0000000..108212c
--- /dev/null
+++ b/src/TNetLib/SharedLinearity.cc
@@ -0,0 +1,277 @@
+
+
+#include "SharedLinearity.h"
+#include "cblas.h"
+
+namespace TNet {
+
+void
+SharedLinearity::
+PropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y)
+{
+  //precopy bias
+  for(int k=0; k<mNInstances; k++) {
+    for(size_t r=0; r<X.Rows(); r++) {
+      memcpy(Y.pRowData(r)+k*mpBias->Dim(),mpBias->pData(),mpBias->Dim()*sizeof(BaseFloat));
+    }
+  }
+
+  //multiply blockwise
+  for(int k=0; k<mNInstances; k++) {
+    SubMatrix<BaseFloat> xblock(X,0,X.Rows(),k*mpLinearity->Rows(),mpLinearity->Rows());
+    SubMatrix<BaseFloat> yblock(Y,0,Y.Rows(),k*mpLinearity->Cols(),mpLinearity->Cols());
+    yblock.BlasGemm(1.0,xblock,NO_TRANS,*mpLinearity,NO_TRANS,1.0);
+  }
+}
+
+
+void
+SharedLinearity::
+BackpropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y)
+{
+  for(int k=0; k<mNInstances; k++) {
+    SubMatrix<BaseFloat> xblock(X,0,X.Rows(),k*mpLinearity->Cols(),mpLinearity->Cols());
+    SubMatrix<BaseFloat> yblock(Y,0,Y.Rows(),k*mpLinearity->Rows(),mpLinearity->Rows());
+    yblock.BlasGemm(1.0,xblock,NO_TRANS,*mpLinearity,TRANS,1.0);
+  }
+}
+
+#if 0
+void
+SharedLinearity::
+AccuUpdate()
+{
+  BaseFloat N = 1;
+  /*
+  //Not part of the interface!!!
+  if(mGradDivFrm) {
+    N = static_cast<BaseFloat>(GetInput().Rows());
+  }
+  */
+  BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
+  N *= mmt_gain; //compensate higher gradient estimates due to momentum
+
+  //compensate augmented dyn. range of gradient caused by multiple instances
+  N *= static_cast<BaseFloat>(mNInstances);
+
+  const Matrix<BaseFloat>& X = GetInput().Data();
+  const Matrix<BaseFloat>& E = GetErrorInput().Data();
+  //get gradient of shared linearity
+  for(int k=0; k<mNInstances; k++) {
+    SubMatrix<BaseFloat> xblock(X,0,X.Rows(),k*mLinearity.Rows(),mLinearity.Rows());
+    SubMatrix<BaseFloat> eblock(E,0,E.Rows(),k*mLinearity.Cols(),mLinearity.Cols());
+    mLinearityCorrection.BlasGemm(1.0,xblock,TRANS,eblock,NO_TRANS,((k==0)?mMomentum:1.0f));
+  }
+
+  //get gradient of shared bias
+  mBiasCorrection.Scale(mMomentum);
+  for(int r=0; r<E.Rows(); r++) {
+    for(int c=0; c<E.Cols(); c++) {
+      mBiasCorrection[c%mBiasCorrection.Dim()] += E(r,c);
+    }
+  }
+
+  //perform update
+  mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection);
+  mBias.AddScaled(-mLearningRate/N,mBiasCorrection);
+
+  //regularization weight decay
+  mLinearity.AddScaled(-mLearningRate*mWeightcost,mLinearity);
+}
+#endif
+
+void
+SharedLinearity::
+ReadFromStream(std::istream& rIn)
+{
+  //number of instances of shared weights in layer
+  rIn >> std::ws >> mNInstances;
+  if(mNInstances < 1) {
+    std::ostringstream os;
+    os << "Bad number of instances:" << mNInstances;
+    Error(os.str());
+  }
+  if(GetNInputs() % mNInstances != 0 || GetNOutputs() % mNInstances != 0) {
+    std::ostringstream os;
+    os << "Number of Inputs/Outputs must be divisible by number of instances"
+       << " Inputs:" << GetNInputs()
+       << " Outputs" << GetNOutputs()
+       << " Intances:" << mNInstances;
+    Error(os.str());
+  }
+
+  //matrix is stored transposed as SNet does
+  BfMatrix transpose;
+  rIn >> transpose;
+  mLinearity = BfMatrix(transpose, TRANS);
+  //biases stored normally
+  rIn >> mBias;
+
+  if(transpose.Cols()*transpose.Rows() == 0) {
+    Error("Missing linearity matrix in network file");
+  }
+  if(mBias.Dim() == 0) {
+    Error("Missing bias vector in network file");
+  }
+
+
+  if(mLinearity.Cols() != (GetNOutputs() / mNInstances) ||
+     mLinearity.Rows() != (GetNInputs() / mNInstances) ||
+     mBias.Dim() != (GetNOutputs() / mNInstances)
+  ){
+    std::ostringstream os;
+    os << "Wrong dimensionalities of matrix/vector in network file\n"
+       << "Inputs:" << GetNInputs()
+       << " Outputs:" << GetNOutputs()
+       << "\n"
+       << "N-Instances:" << mNInstances
+       << "\n"
+       << "linearityCols:" << mLinearity.Cols() << "(" << mLinearity.Cols()*mNInstances << ")"
+       << " linearityRows:" << mLinearity.Rows() << "(" << mLinearity.Rows()*mNInstances << ")"
+       << " biasDims:" << mBias.Dim() << "(" << mBias.Dim()*mNInstances << ")"
+       << "\n";
+    Error(os.str());
+  }
+
+  mLinearityCorrection.Init(mLinearity.Rows(),mLinearity.Cols());
+  mBiasCorrection.Init(mBias.Dim());
+}
+
+
+void
+SharedLinearity::
+WriteToStream(std::ostream& rOut)
+{
+  rOut << mNInstances << std::endl;
+  //matrix is stored transposed as SNet does
+  BfMatrix transpose(mLinearity, TRANS);
+  rOut << transpose;
+  //biases stored normally
+  rOut << mBias;
+  rOut << std::endl;
+}
+
+
+void
+SharedLinearity::
+Gradient()
+{
+  const Matrix<BaseFloat>& X = GetInput();
+  const Matrix<BaseFloat>& E = GetErrorInput();
+  //get gradient of shared linearity
+  for(int k=0; k<mNInstances; k++) {
+    SubMatrix<BaseFloat> xblock(X,0,X.Rows(),k*mpLinearity->Rows(),mpLinearity->Rows());
+    SubMatrix<BaseFloat> eblock(E,0,E.Rows(),k*mpLinearity->Cols(),mpLinearity->Cols());
+    mLinearityCorrection.BlasGemm(1.0,xblock,TRANS,eblock,NO_TRANS,((k==0)?0.0f:1.0f));
+  }
+
+  //get gradient of shared bias
+  mBiasCorrection.Set(0.0f);
+  for(int r=0; r<E.Rows(); r++) {
+    for(int c=0; c<E.Cols(); c++) {
+      mBiasCorrection[c%mBiasCorrection.Dim()] += E(r,c);
+    }
+  }
+}
+
+
+void
+SharedLinearity::
+AccuGradient(const UpdatableComponent& src, int thr, int thrN)
+{
+  //cast the argument
+  const SharedLinearity& src_comp = dynamic_cast<const SharedLinearity&>(src);
+
+  //allocate accumulators when needed
+  if(mLinearityCorrectionAccu.MSize() == 0) {
+    mLinearityCorrectionAccu.Init(mpLinearity->Rows(),mpLinearity->Cols());
+  }
+  if(mBiasCorrectionAccu.MSize() == 0) {
+    mBiasCorrectionAccu.Init(mpBias->Dim());
+  }
+
+
+  //assert the dimensions
+  /*
+  assert(mLinearityCorrection.Rows() == src_comp.mLinearityCorrection.Rows());
+  assert(mLinearityCorrection.Cols() == src_comp.mLinearityCorrection.Cols());
+  assert(mBiasCorrection.Dim() == src_comp.mBiasCorrection.Dim());
+  */
+
+  //need to find out which rows to sum...
+  int div = mLinearityCorrection.Rows() / thrN;
+  int mod = mLinearityCorrection.Rows() % thrN;
+
+  int origin = thr * div + ((mod > thr)? thr : mod);
+  int rows = div + ((mod > thr)? 1 : 0);
+
+  //std::cout << "[S" << thr << "," << origin << "," << rows << "]" << std::flush;
+
+  //create the matrix windows
+  const SubMatrix<BaseFloat> src_mat (
+    src_comp.mLinearityCorrection,
+    origin, rows,
+    0, mLinearityCorrection.Cols()
+  );
+  SubMatrix<double> tgt_mat (
+    mLinearityCorrectionAccu,
+    origin, rows,
+    0, mLinearityCorrection.Cols()
+  );
+  //sum the rows
+  Add(tgt_mat,src_mat);
+
+  //first thread will always sum the bias correction and adds frame count
+  if(thr == 0) {
+    //std::cout << "[BS" << thr << "]" << std::flush;
+    Add(mBiasCorrectionAccu,src_comp.mBiasCorrection);
+  }
+}
+
+
+void
+SharedLinearity::
+Update(int thr, int thrN)
+{
+  //need to find out which rows to sum...
+  int div = mLinearity.Rows() / thrN;
+  int mod = mLinearity.Rows() % thrN;
+
+  int origin = thr * div + ((mod > thr)? thr : mod);
+  int rows = div + ((mod > thr)? 1 : 0);
+
+  //std::cout << "[P" << thr << "," << origin << "," << rows << "]" << std::flush;
+
+  //get the matrix windows
+  SubMatrix<double> src_mat (
+    mLinearityCorrectionAccu,
+    origin, rows,
+    0, mLinearityCorrection.Cols()
+  );
+  SubMatrix<BaseFloat> tgt_mat (
+    mLinearity,
+    origin, rows,
+    0, mLinearityCorrection.Cols()
+  );
+
+  //TODO perform L2 regularization
+  //tgt_mat.AddScaled(tgt_mat, -mWeightcost * num_frames);
+
+  //update weights
+  AddScaled(tgt_mat, src_mat, -mLearningRate/static_cast<BaseFloat>(mNInstances));
+
+  //first thread always update bias
+  if(thr == 0) {
+    //std::cout << "[" << thr << "BP]" << std::flush;
+    AddScaled(mBias, mBiasCorrectionAccu, -mLearningRate/static_cast<BaseFloat>(mNInstances));
+  }
+
+  //reset the accumulators
+  src_mat.Zero();
+  if(thr == 0) {
+    mBiasCorrectionAccu.Zero();
+  }
+}
+
+
+} //namespace
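For reference (not part of the commit), the forward pass implemented by PropagateFnc above applies one shared weight matrix and one shared bias to each of mNInstances consecutive column blocks of the input. Below is a minimal stand-alone sketch of that idea using plain std::vector instead of TNet's Matrix/SubMatrix/BlasGemm; the names SharedLinearityForward, in_dim, out_dim, and n_instances are illustrative, not taken from the TNet code.

#include <cstddef>
#include <iostream>
#include <vector>

// Toy illustration of a shared linearity: one weight matrix W (in_dim x out_dim,
// row-major) and one bias b are reused across n_instances consecutive blocks of
// the input row x. Each block x_k (length in_dim) yields y_k = x_k * W + b.
std::vector<float> SharedLinearityForward(const std::vector<float>& x,
                                          const std::vector<float>& W,
                                          const std::vector<float>& b,
                                          std::size_t in_dim, std::size_t out_dim,
                                          std::size_t n_instances) {
  std::vector<float> y(n_instances * out_dim);
  for (std::size_t k = 0; k < n_instances; ++k) {
    for (std::size_t c = 0; c < out_dim; ++c) {
      float acc = b[c];                                   // shared bias ("precopy bias" step)
      for (std::size_t r = 0; r < in_dim; ++r) {
        acc += x[k * in_dim + r] * W[r * out_dim + c];    // shared weights, blockwise product
      }
      y[k * out_dim + c] = acc;
    }
  }
  return y;
}

int main() {
  // 2 instances, 2 inputs and 2 outputs per instance: x has 4 values, y has 4 values.
  std::vector<float> W = {1, 0,
                          0, 1};          // identity weights for readability
  std::vector<float> b = {0.5f, -0.5f};
  std::vector<float> x = {1, 2, 3, 4};
  std::vector<float> y = SharedLinearityForward(x, W, b, 2, 2, 2);
  for (float v : y) std::cout << v << ' '; // prints: 1.5 1.5 3.5 3.5
  std::cout << '\n';
}

The same blocking scheme is what the commit expresses with SubMatrix windows and BlasGemm calls, and the gradient code accumulates over the blocks in the analogous way.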