summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuUpdatableBias.cc
blob: 2a9cbed0b516b499b25d7f9702db00d5da14009d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96


#include "cuUpdatableBias.h"


namespace TNet
{

  void 
  CuUpdatableBias::
  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    //Y.SetConst(0.0);
    Y.AddScaledRow(1.0,mBias,0.0);
    Y.AddScaled(1.0,X,1.0);
  }


  void 
  CuUpdatableBias::
  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    //Y.SetConst(0.0);
    Y.CopyFrom(X);
  }

  
  void 
  CuUpdatableBias::
  Update() 
  {
#if 0
    //former implementation
    BaseFloat N = static_cast<BaseFloat>(GetInput().Rows());

    mBiasCorrection.AddColSum(-mLearningRate/N,GetErrorInput(),mMomentum);

    mBias.AddScaled(1.0,mBiasCorrection,1.0);
#endif

#if 1
    //new implementation
    BaseFloat N = 1;
    if(mGradDivFrm) {
      N = static_cast<BaseFloat>(GetInput().Rows());
    }
    BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
    N *= mmt_gain;

    mBiasCorrection.AddColSum(1.0,GetErrorInput(),mMomentum);

    mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);

#endif
  }


  void
  CuUpdatableBias::
  ReadFromStream(std::istream& rIn)
  {
    // The bias is stored as a single plain vector; read it into a host copy.
    BfVector bias;
    rIn >> bias;

    // Validate the host copy before uploading it, so a malformed network
    // file does not overwrite mBias (if Error() throws). This also covers
    // a missing/empty bias vector whenever GetNOutputs() > 0.
    if (bias.Dim() != GetNOutputs()) {
      std::ostringstream os;
      os << "Wrong dimensionalities of matrix/vector in network file\n"
         << "Inputs:" << GetNInputs()
         << " Outputs:" << GetNOutputs()
         << "\n"
         << "biasDims:" << bias.Dim()
         << "\n";
      Error(os.str());
    }

    mBias.CopyFrom(bias);
  }

   
  void
  CuUpdatableBias::
  WriteToStream(std::ostream& rOut)
  {
    BfVector vec;
    mBias.CopyTo(vec);
    rOut << vec;
    rOut << std::endl;
  }

 
} //namespace