summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuBiasedLinearity.cc
blob: 830e03e001a594c2098d2224994c4403bf629c1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


#include "cuBiasedLinearity.h"


namespace TNet
{

  // Forward pass of the affine layer:  Y = X * mLinearity + mBias.
  //
  // mLinearity is NInputs x NOutputs (see the dimension check in
  // ReadFromStream), so each row of X is mapped to a row of Y.
  //
  //  X - input activations, one row per frame
  //  Y - output activations; the bias is written into every row first,
  //      presumably discarding Y's previous contents via the 0.0 scale
  //      argument of AddScaledRow (NOTE(review): confirm in CuMatrix)
  void
  CuBiasedLinearity::PropagateFnc(const CuMatrix<BaseFloat>& X,
                                  CuMatrix<BaseFloat>& Y)
  {
    // step 1: broadcast the bias row into each output row
    Y.AddScaledRow(1.0, mBias, 0.0);

    // step 2: accumulate the linear term on top of the bias,
    //         BLAS-style Gemm(transA, transB, alpha, A, B, beta)
    Y.Gemm('N', 'N', 1.0, X, mLinearity, 1.0);
  }


  // Backward pass through the linear part of the layer.
  //
  // X is the error at this layer's output (NOutputs columns); Y receives the
  // error propagated to this layer's input (NInputs columns):
  //
  //     Y = X * mLinearity^T
  //
  // The bias does not influence the input-side error, so it is not used
  // here. Y's previous contents are presumably overwritten because the final
  // Gemm argument (beta) is 0.0.
  //
  // Fix: removed a leftover debug print of mLinearity's dimensions to
  // std::cout that ran (and flushed stdout) on every backward pass.
  void 
  CuBiasedLinearity::
  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Y.Gemm('N', 'T', 1.0, X, mLinearity, 0.0);
  }

  
  void 
  CuBiasedLinearity::
  Update() 
  {
#if 0
    //former implementation
    BaseFloat N = static_cast<BaseFloat>(GetInput().Rows());

    mLinearityCorrection.Gemm('T','N',-mLearningRate/N,GetInput(),GetErrorInput(),mMomentum);
    mBiasCorrection.AddColSum(-mLearningRate/N,GetErrorInput(),mMomentum);

    //regularization weight decay
    mLinearityCorrection.AddScaled(-mLearningRate*mWeightcost,mLinearity,1.0);
    
    mLinearity.AddScaled(1.0,mLinearityCorrection,1.0);
    mBias.AddScaled(1.0,mBiasCorrection,1.0);
#endif

#if 1
    //new implementation
    BaseFloat N = 1;
    if(mGradDivFrm) {
      N = static_cast<BaseFloat>(GetInput().Rows());
    }
    BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
    N *= mmt_gain;

    mLinearityCorrection.Gemm('T','N',1.0,GetInput(),GetErrorInput(),mMomentum);
    mBiasCorrection.AddColSum(1.0,GetErrorInput(),mMomentum);

    mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection,1.0);
    mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);

    //regularization weight decay (from actual weights only)
    BaseFloat L2_decay = -mLearningRate*mWeightcost*(mGradDivFrm?1.0:GetInput().Rows());
    mLinearity.AddScaled(L2_decay, mLinearity,1.0);
#endif
  }


  // Load the layer parameters from a network file.
  //
  // Stream layout:
  //   1. the linearity matrix, stored transposed as SNet does
  //      (NOutputs x NInputs on disk);
  //   2. the bias vector (NOutputs entries), stored as-is.
  // The matrix is transposed back on load, so mLinearity ends up
  // NInputs x NOutputs.
  //
  // Calls Error() if either parameter is missing or if the loaded dimensions
  // do not match GetNInputs()/GetNOutputs().
  //
  // Changes vs. the previous version:
  //  - the missing-parameter checks run on the host copies, before anything
  //    is uploaded with CopyFrom;
  //  - the dimension-mismatch message now separates its fields; they used to
  //    run together, e.g. "Inputs:5Outputs:3".
  void
  CuBiasedLinearity::
  ReadFromStream(std::istream& rIn)
  {
    // matrix is stored transposed as SNet does
    BfMatrix transpose;
    rIn >> transpose;
    // biases stored normally
    BfVector bias;
    rIn >> bias;

    // Validate what was read before uploading it to the device.
    if(transpose.Rows() == 0 || transpose.Cols() == 0) {
      Error("Missing linearity matrix in network file");
    }
    if(bias.Dim() == 0) {
      Error("Missing bias vector in network file");
    }

    mLinearity.CopyFrom(BfMatrix(transpose, TRANS));
    mBias.CopyFrom(bias);

    if(mLinearity.Cols() != GetNOutputs() || 
       mLinearity.Rows() != GetNInputs() ||
       mBias.Dim() != GetNOutputs()
    ){
      std::ostringstream os;
      os << "Wrong dimensionalities of matrix/vector in network file\n"
         << "Inputs:" << GetNInputs()
         << " Outputs:" << GetNOutputs()
         << "\n"
         << "linearityCols:" << mLinearity.Cols()
         << " linearityRows:" << mLinearity.Rows()
         << " biasDims:" << mBias.Dim()
         << "\n";
      Error(os.str());
    }
  }

   
  void
  CuBiasedLinearity::
  WriteToStream(std::ostream& rOut)
  {
    //matrix is stored transposed as SNet does
    BfMatrix tmp;
    mLinearity.CopyTo(tmp);
    BfMatrix transpose(tmp, TRANS);
    rOut << transpose;
    //biases stored normally
    BfVector vec;
    mBias.CopyTo(vec);
    rOut << vec;
    rOut << std::endl;
  }

 
} //namespace