summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuUpdatableBias.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/CuTNetLib/cuUpdatableBias.cc')
-rw-r--r--src/CuTNetLib/cuUpdatableBias.cc96
1 files changed, 96 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuUpdatableBias.cc b/src/CuTNetLib/cuUpdatableBias.cc
new file mode 100644
index 0000000..2a9cbed
--- /dev/null
+++ b/src/CuTNetLib/cuUpdatableBias.cc
@@ -0,0 +1,96 @@
+
+
+#include "cuUpdatableBias.h"
+
+
+namespace TNet
+{
+
+ void
+ CuUpdatableBias::
+ PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ //Y.SetConst(0.0);
+ Y.AddScaledRow(1.0,mBias,0.0);
+ Y.AddScaled(1.0,X,1.0);
+ }
+
+
+ void
+ CuUpdatableBias::
+ BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ //Y.SetConst(0.0);
+ Y.CopyFrom(X);
+ }
+
+
+ void
+ CuUpdatableBias::
+ Update()
+ {
+#if 0
+ //former implementation
+ BaseFloat N = static_cast<BaseFloat>(GetInput().Rows());
+
+ mBiasCorrection.AddColSum(-mLearningRate/N,GetErrorInput(),mMomentum);
+
+ mBias.AddScaled(1.0,mBiasCorrection,1.0);
+#endif
+
+#if 1
+ //new implementation
+ BaseFloat N = 1;
+ if(mGradDivFrm) {
+ N = static_cast<BaseFloat>(GetInput().Rows());
+ }
+ BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
+ N *= mmt_gain;
+
+ mBiasCorrection.AddColSum(1.0,GetErrorInput(),mMomentum);
+
+ mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);
+
+#endif
+ }
+
+
+ void
+ CuUpdatableBias::
+ ReadFromStream(std::istream& rIn)
+ {
+ //biases stored normally
+ BfVector bias;
+ rIn >> bias;
+ mBias.CopyFrom(bias);
+
+ /*if(bias.Dim() == 0) {
+ Error("Missing bias vector in network file");
+ }*/
+ if( mBias.Dim() != GetNOutputs()
+ ){
+ std::ostringstream os;
+ os << "Wrong dimensionalities of matrix/vector in network file\n"
+ << "Inputs:" << GetNInputs()
+ << "Outputs:" << GetNOutputs()
+ << "\n"
+ << "biasDims:" << mBias.Dim()
+ << "\n";
+ Error(os.str());
+ }
+ }
+
+
+ void
+ CuUpdatableBias::
+ WriteToStream(std::ostream& rOut)
+ {
+ BfVector vec;
+ mBias.CopyTo(vec);
+ rOut << vec;
+ rOut << std::endl;
+ }
+
+
+} //namespace
+