Diffstat (limited to 'src/CuTNetLib/cuSparseLinearity.cc')
-rw-r--r--  src/CuTNetLib/cuSparseLinearity.cc  190
1 file changed, 190 insertions(+), 0 deletions(-)
diff --git a/src/CuTNetLib/cuSparseLinearity.cc b/src/CuTNetLib/cuSparseLinearity.cc
new file mode 100644
index 0000000..7209630
--- /dev/null
+++ b/src/CuTNetLib/cuSparseLinearity.cc
@@ -0,0 +1,190 @@
+
+
+#include "cuSparseLinearity.h"
+#include <cmath>
+#include <cstdlib>
+
+
+namespace TNet
+{
+
+ void
+ CuSparseLinearity::
+ PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
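+    //forward pass: Y = X * W + b (the bias is broadcast to every row of Y
+    //first, then the GEMM accumulates the linear term on top of it)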
+ Y.AddScaledRow(1.0,mBias,0.0);
+ Y.Gemm('N','N', 1.0, X, mLinearity, 1.0);
+ }
+
+
+ void
+ CuSparseLinearity::
+ BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
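+    //backward pass: propagate the output error through the transposed
+    //weights, Y = X * W^T (the bias has no input-side gradient)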
+ Y.Gemm('N', 'T', 1.0, X, mLinearity, 0.0);
+ }
+
+
+ void
+ CuSparseLinearity::
+ Update()
+ {
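+    //the weight update below is scaled by 1/N: optionally the number of
+    //frames in the minibatch, times the momentum gain 1/(1-momentum),
+    //so the effective step size stays comparable across configurations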
+ BaseFloat N = 1;
+ if(mGradDivFrm) {
+ N = static_cast<BaseFloat>(GetInput().Rows());
+ }
+ BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
+ N *= mmt_gain;
+
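+    //momentum-smoothed gradients: correction = gradient + momentum * correction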
+ mLinearityCorrection.Gemm('T','N',1.0,GetInput(),GetErrorInput(),mMomentum);
+ mBiasCorrection.AddColSum(1.0,GetErrorInput(),mMomentum);
+
+ mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection,1.0);
+ mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);
+
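+    //keep a running sum of the raw gradients for UpdateMask() and
+    //re-zero the weights that the sparsity mask has pruned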
+ mLinearityCorrectionAccu.AddScaled(1.0,mLinearityCorrection,1.0);
+ mLinearity.ApplyMask(mSparsityMask);
+
+    //L1 regularization (lasso)...
+    //each update? every 1000th update?
+ if(mL1Const > 0) {
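+      //the shrinkage constant follows the same frame-count convention as
+      //the weight update above (scaled by the batch size when the
+      //gradient is not divided by the number of frames)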
+ BaseFloat L1_const = mLearningRate*mL1Const*(mGradDivFrm?1.0:GetInput().Rows());
+ mLinearity.ApplyL1(L1_const);
+ }
+
+ //L2 regularization weight decay (from actual weights only)
+ if(mWeightcost > 0) {
+ BaseFloat L2_decay = -mLearningRate*mWeightcost*(mGradDivFrm?1.0:GetInput().Rows());
+ mLinearity.AddScaled(L2_decay, mLinearity,1.0);
+ }
+
+ mNFrames += GetInput().Rows();
+
+ }
+
+
+ void
+ CuSparseLinearity::
+ UpdateMask()
+ {
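+    //prune active weights whose magnitude fell below
+    //mSparsifyWeightThreshold, re-activate inactive weights whose average
+    //accumulated gradient per frame exceeds mUnsparsifyAccu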
+ //move data to host
+ Matrix<BaseFloat> linearity, linearity_correction_accu;
+ Matrix<BaseFloat> sparsity_mask;
+
+ mLinearity.CopyTo(linearity);
+ mLinearityCorrectionAccu.CopyTo(linearity_correction_accu);
+ mSparsityMask.CopyTo(sparsity_mask);
+
+ //decide on new sparsity mask
+ for(size_t r=0; r<sparsity_mask.Rows(); r++) {
+ for(size_t c=0; c<sparsity_mask.Cols(); c++) {
+ if(sparsity_mask(r,c) == 1.0f) { //weight active
+ if(fabs(linearity(r,c)) < mSparsifyWeightThreshold) {
+ sparsity_mask(r,c) = 0;//deactivate
+ linearity(r,c) = 0;
+ }
+ } else { //weight inactive
+        if(fabs(linearity_correction_accu(r,c))/(BaseFloat)mNFrames > mUnsparsifyAccu) {
+ sparsity_mask(r,c) = 1;//activate
+ }
+ }
+ }
+ }
+
+ //move data to the device
+ mLinearity.CopyFrom(linearity);
+ mSparsityMask.CopyFrom(sparsity_mask);
+ }
+
+
+ void
+ CuSparseLinearity::
+ ReadFromStream(std::istream& rIn)
+ {
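+    //stream layout: transposed weight matrix, bias vector, then an
+    //optional sparsity mask and an optional (discarded) accumulator matrix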
+    //the weight matrix is stored transposed, following the SNet convention
+ BfMatrix transpose;
+ rIn >> transpose;
+ mLinearity.CopyFrom(BfMatrix(transpose, TRANS));
+ //biases stored normally
+ BfVector bias;
+ rIn >> bias;
+ mBias.CopyFrom(bias);
+
+ //sparsity mask
+ rIn >> std::ws;
+ Matrix<BaseFloat> mask_transp;
+ if(rIn.peek() == 'm') {//load from file
+ rIn >> mask_transp;
+ } else {//or set all elements active
+ mask_transp.Init(transpose.Rows(),transpose.Cols());
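+      //fill the whole allocated area (rows x stride, padding included)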
+      int items=mask_transp.Rows()*mask_transp.Stride();
+ BaseFloat* p = mask_transp.pData();
+ for(int i=0; i<items; i++) {//set all elements to one
+ *p++ = 1;
+ }
+ }
+ mSparsityMask.CopyFrom(BfMatrix(mask_transp,TRANS));
+
+    //dummy matrix with accumulated gradients
+ rIn >> std::ws;
+ if(rIn.peek() == 'm') {//load from file
+ BfMatrix dummy;
+ rIn >> dummy;
+ }
+
+ if(transpose.Cols()*transpose.Rows() == 0) {
+ Error("Missing linearity matrix in network file");
+ }
+ if(bias.Dim() == 0) {
+ Error("Missing bias vector in network file");
+ }
+ if(mLinearity.Cols() != GetNOutputs() ||
+ mLinearity.Rows() != GetNInputs() ||
+ mBias.Dim() != GetNOutputs()
+ ){
+ std::ostringstream os;
+      os << "Wrong dimensionalities of matrix/vector in network file\n"
+         << "Inputs:" << GetNInputs()
+         << " Outputs:" << GetNOutputs()
+         << "\n"
+         << "linearityCols:" << mLinearity.Cols()
+         << " linearityRows:" << mLinearity.Rows()
+         << " biasDims:" << mBias.Dim()
+         << "\n";
+ Error(os.str());
+ }
+
+ assert(mLinearity.Rows() == mSparsityMask.Rows());
+ assert(mLinearity.Cols() == mSparsityMask.Cols());
+
+ }
+
+
+ void
+ CuSparseLinearity::
+ WriteToStream(std::ostream& rOut)
+ {
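+    //re-derive the sparsity mask so the stored weights and mask agree;
+    //the output order mirrors ReadFromStream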
+ UpdateMask();
+
+    //the weight matrix is stored transposed, following the SNet convention
+ BfMatrix tmp;
+ mLinearity.CopyTo(tmp);
+ BfMatrix transpose(tmp, TRANS);
+ rOut << transpose;
+ //biases stored normally
+ BfVector vec;
+ mBias.CopyTo(vec);
+ rOut << vec;
+ rOut << std::endl;
+ //store mask
+ mSparsityMask.CopyTo(tmp);
+ rOut << BfMatrix(tmp,TRANS);
+ //store accu
+ mLinearityCorrectionAccu.CopyTo(tmp);
+ rOut << BfMatrix(tmp,TRANS);
+
+ }
+
+
+} //namespace TNet
+