summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuSharedLinearity.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/CuTNetLib/cuSharedLinearity.cc')
-rw-r--r--src/CuTNetLib/cuSharedLinearity.cc179
1 files changed, 179 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuSharedLinearity.cc b/src/CuTNetLib/cuSharedLinearity.cc
new file mode 100644
index 0000000..8d5ec09
--- /dev/null
+++ b/src/CuTNetLib/cuSharedLinearity.cc
@@ -0,0 +1,179 @@
+
+
+#include "cuSharedLinearity.h"
+#include "cumath.h"
+
+
+namespace TNet
+{
+
+ void
+ CuSharedLinearity::
+ PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ CuMath<BaseFloat>::VecExpand(mBias,mBiasExpand); /// [ 1 2 3 ] -> [ 1 2 3 1 2 3 ... ]
+ Y.AddScaledRow(1.0,mBiasExpand,0.0);
+
+ //mBiasExpand.Print();
+
+ for(int i=0; i<mNInstances; i++) {
+ CuMath<BaseFloat>::OffsetGemm('N','N', 1.0, X, mLinearity, 1.0, Y,
+ i*mLinearity.Rows(), 0, i*mLinearity.Cols());
+ }
+ //std::cout << CuDevice::Instantiate().GetFreeMemory();
+ //GetInput().Print();
+ //GetOutput().Print();
+ }
+
+
+ void
+ CuSharedLinearity::
+ BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ for(int i=0; i<mNInstances; i++) {
+ CuMath<BaseFloat>::OffsetGemm('N', 'T', 1.0, X, mLinearity, 0.0, Y,
+ i*mLinearity.Cols(), 0, i*mLinearity.Rows());
+ }
+ }
+
+
+ void
+ CuSharedLinearity::
+ Update()
+ {
+#if 0
+ //former implementation
+ BaseFloat N = static_cast<BaseFloat>(GetInput().Rows());
+
+ for(int i=0; i<mNInstances; i++) {
+ CuMath<BaseFloat>::OffsetGemm('T','N',-mLearningRate/(N*mNInstances),
+ GetInput(),GetErrorInput(),
+ ((i==0)?mMomentum:1.0f), mLinearityCorrection,
+ i*mLinearity.Rows(),i*mLinearity.Cols(),0);
+ }
+ mBiasCorrectionExpand.AddColSum(1.0,GetErrorInput(),0.0);
+ CuMath<BaseFloat>::VecAddColSum(-mLearningRate/(N*mNInstances),mBiasCorrectionExpand,mMomentum,mBiasCorrection);
+
+
+ //regularization weight decay
+ mLinearityCorrection.AddScaled(-mLearningRate*mWeightcost,mLinearity,1.0);
+
+ mLinearity.AddScaled(1.0,mLinearityCorrection,1.0);
+ mBias.AddScaled(1.0,mBiasCorrection,1.0);
+#endif
+
+#if 1
+ //new implementation
+ BaseFloat N = 1;
+ if(mGradDivFrm) {
+ N = static_cast<BaseFloat>(GetInput().Rows());
+ }
+ BaseFloat mmt_gain = static_cast<BaseFloat>(1.0/(1.0-mMomentum));
+ N *= mmt_gain; //compensate higher gradient estimates due to momentum
+
+ //compensate augmented dyn. range of gradient caused by multiple instances
+ N *= static_cast<BaseFloat>(mNInstances);
+
+ //get gradient of shared linearity
+ for(int i=0; i<mNInstances; i++) {
+ CuMath<BaseFloat>::OffsetGemm('T','N',1.0,
+ GetInput(),GetErrorInput(),
+ ((i==0)?mMomentum:1.0f), mLinearityCorrection,
+ i*mLinearity.Rows(),i*mLinearity.Cols(),0);
+ }
+ //get gradient of shared bias
+ mBiasCorrectionExpand.AddColSum(1.0,GetErrorInput(),0.0);
+ CuMath<BaseFloat>::VecAddColSum(1.0,mBiasCorrectionExpand,mMomentum,mBiasCorrection);
+
+ //perform update
+ mLinearity.AddScaled(-mLearningRate/N,mLinearityCorrection,1.0);
+ mBias.AddScaled(-mLearningRate/N,mBiasCorrection,1.0);
+
+ //regularization weight decay
+ mLinearity.AddScaled(-mLearningRate*mWeightcost,mLinearity,1.0);
+#endif
+
+ }
+
+
+ void
+ CuSharedLinearity::
+ ReadFromStream(std::istream& rIn)
+ {
+ //number of instances of shared weights in layer
+ rIn >> std::ws >> mNInstances;
+ if(mNInstances < 1) {
+ std::ostringstream os;
+ os << "Bad number of instances:" << mNInstances;
+ Error(os.str());
+ }
+ if(GetNInputs() % mNInstances != 0 || GetNOutputs() % mNInstances != 0) {
+ std::ostringstream os;
+ os << "Number of Inputs/Outputs must be divisible by number of instances"
+ << " Inputs:" << GetNInputs()
+ << " Outputs" << GetNOutputs()
+ << " Intances:" << mNInstances;
+ Error(os.str());
+ }
+
+ //matrix is stored transposed as SNet does
+ BfMatrix transpose;
+ rIn >> transpose;
+ mLinearity.CopyFrom(BfMatrix(transpose, TRANS));
+ //biases stored normally
+ BfVector bias;
+ rIn >> bias;
+ mBias.CopyFrom(bias);
+
+ if(transpose.Cols()*transpose.Rows() == 0) {
+ Error("Missing linearity matrix in network file");
+ }
+ if(bias.Dim() == 0) {
+ Error("Missing bias vector in network file");
+ }
+
+
+ if(mLinearity.Cols() != GetNOutputs() / mNInstances ||
+ mLinearity.Rows() != GetNInputs() / mNInstances ||
+ mBias.Dim() != GetNOutputs() / mNInstances
+ ){
+ std::ostringstream os;
+ os << "Wrong dimensionalities of matrix/vector in network file\n"
+ << "Inputs:" << GetNInputs()
+ << "Outputs:" << GetNOutputs()
+ << "\n"
+ << "linearityCols:" << mLinearity.Cols()
+ << "linearityRows:" << mLinearity.Rows()
+ << "biasDims:" << mBias.Dim()
+ << "\n";
+ Error(os.str());
+ }
+
+ mLinearityCorrection.Init(mLinearity.Rows(),mLinearity.Cols());
+ mBiasCorrection.Init(mBias.Dim());
+
+ mBiasExpand.Init(mBias.Dim()*mNInstances);
+ mBiasCorrectionExpand.Init(mBias.Dim()*mNInstances);
+ }
+
+
+ void
+ CuSharedLinearity::
+ WriteToStream(std::ostream& rOut)
+ {
+ rOut << mNInstances << std::endl;
+
+ //matrix is stored transposed as SNet does
+ BfMatrix tmp;
+ mLinearity.CopyTo(tmp);
+ BfMatrix transpose(tmp, TRANS);
+ rOut << transpose;
+ //biases stored normally
+ BfVector vec;
+ mBias.CopyTo(vec);
+ rOut << vec;
+ rOut << std::endl;
+ }
+
+
+} //namespace