Diffstat (limited to 'src/TNetLib/BiasedLinearity.cc')
-rw-r--r--  src/TNetLib/BiasedLinearity.cc  180
1 file changed, 180 insertions, 0 deletions
diff --git a/src/TNetLib/BiasedLinearity.cc b/src/TNetLib/BiasedLinearity.cc
new file mode 100644
index 0000000..b52aeb0
--- /dev/null
+++ b/src/TNetLib/BiasedLinearity.cc
@@ -0,0 +1,180 @@
+
+
+#include "BiasedLinearity.h"
+
+
+namespace TNet {
+
+
+void
+BiasedLinearity::
+PropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y)
+{
+ //y = b + x.A
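+  //(shapes implied by the GEMM below: X is frames x input_dim,
+  // *mpLinearity is input_dim x output_dim, *mpBias has output_dim elements)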
+
+  //pre-copy the bias into every row of the output
+ size_t rows = X.Rows();
+ for(size_t i=0; i<rows; i++) {
+ Y[i].Copy(*mpBias);
+ }
+
+  //multiply the input by the weight matrix, accumulating on top of the biases (beta=1.0)
+ Y.BlasGemm(1.0f, X, NO_TRANS, *mpLinearity, NO_TRANS, 1.0f);
+}
+
+
+void
+BiasedLinearity::
+BackpropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y)
+{
+  // e' = e.A^T
+  Y.Zero(); //redundant: beta=0.0f in the GEMM below overwrites Y anyway
+ Y.BlasGemm(1.0f, X, NO_TRANS, *mpLinearity, TRANS, 0.0f);
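+  //(chain rule for y = b + x.A: dE/dx = dE/dy . A^T; the constant bias
+  // does not contribute to the gradient w.r.t. the input)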
+}
+
+
+
+void
+BiasedLinearity::
+ReadFromStream(std::istream& rIn)
+{
+  //the weight matrix is stored transposed, following the SNet convention
+ Matrix<BaseFloat> transpose;
+ rIn >> transpose;
+ mLinearity = Matrix<BaseFloat>(transpose, TRANS);
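+  //(the (matrix, TRANS) copy-constructor transposes it back into the
+  // in-memory layout that PropagateFnc() expects)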
+ //biases stored normally
+ rIn >> mBias;
+}
+
+
+void
+BiasedLinearity::
+WriteToStream(std::ostream& rOut)
+{
+  //the weight matrix is stored transposed, following the SNet convention
+ Matrix<BaseFloat> transpose(mLinearity, TRANS);
+ rOut << transpose;
+ //biases stored normally
+ rOut << mBias;
+ rOut << std::endl;
+}
+
+
+void
+BiasedLinearity::
+Gradient()
+{
+ //calculate gradient of weight matrix
+  mLinearityCorrection.Zero(); //redundant: beta=0.0f below overwrites it
+  mLinearityCorrection.BlasGemm(1.0f, GetInput(), TRANS,
+                                GetErrorInput(), NO_TRANS,
+                                0.0f);
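+  //(with TRANS on the input this computes X^T . E, so the gradient
+  // is summed over all frames of the bunch)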
+
+ //calculate gradient of bias
+ mBiasCorrection.Set(0.0f);
+ size_t rows = GetInput().Rows();
+ for(size_t i=0; i<rows; i++) {
+ mBiasCorrection.Add(GetErrorInput()[i]);
+ }
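+  //(the result is the column-wise sum of the error matrix over the bunch)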
+
+  /*
+  //immediate update, disabled: the update is applied later in Update(),
+  //after AccuGradient() has summed the gradients of all threads
+  mLinearity.AddScaled(-mLearningRate, mLinearityCorrection);
+  mBias.AddScaled(-mLearningRate, mBiasCorrection);
+  */
+}
+
+
+void
+BiasedLinearity::
+AccuGradient(const UpdatableComponent& src, int thr, int thrN)
+{
+ //cast the argument
+ const BiasedLinearity& src_comp = dynamic_cast<const BiasedLinearity&>(src);
+
+ //allocate accumulators when needed
+ if(mLinearityCorrectionAccu.MSize() == 0) {
+ mLinearityCorrectionAccu.Init(mLinearity.Rows(),mLinearity.Cols());
+ }
+ if(mBiasCorrectionAccu.MSize() == 0) {
+ mBiasCorrectionAccu.Init(mBias.Dim());
+ }
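+  //(the accumulators are double precision, presumably to limit rounding
+  // error when summing many single-precision gradients)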
+
+  //split the rows evenly among the threads; the first 'mod' threads get one extra row
+ int div = mLinearityCorrection.Rows() / thrN;
+ int mod = mLinearityCorrection.Rows() % thrN;
+
+ int origin = thr * div + ((mod > thr)? thr : mod);
+ int rows = div + ((mod > thr)? 1 : 0);
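+  //(illustration: 10 rows split among 3 threads gives 4+3+3 rows
+  // with origins 0, 4 and 7)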
+
+ //create the matrix windows
+ const SubMatrix<BaseFloat> src_mat (
+ src_comp.mLinearityCorrection,
+ origin, rows,
+ 0, mLinearityCorrection.Cols()
+ );
+ SubMatrix<double> tgt_mat (
+ mLinearityCorrectionAccu,
+ origin, rows,
+ 0, mLinearityCorrection.Cols()
+ );
+ //sum the rows
+ Add(tgt_mat,src_mat);
+
+ //first thread will always sum the bias correction
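+  //(the bias correction is a single vector, so splitting it across
+  // threads is presumably not worth the bookkeeping)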
+ if(thr == 0) {
+ Add(mBiasCorrectionAccu,src_comp.mBiasCorrection);
+ }
+
+}
+
+
+void
+BiasedLinearity::
+Update(int thr, int thrN)
+{
+  //split the rows to update among the threads, same scheme as in AccuGradient()
+ int div = mLinearity.Rows() / thrN;
+ int mod = mLinearity.Rows() % thrN;
+
+ int origin = thr * div + ((mod > thr)? thr : mod);
+ int rows = div + ((mod > thr)? 1 : 0);
+
+ //std::cout << "[P" << thr << "," << origin << "," << rows << "]" << std::flush;
+
+ //get the matrix windows
+ SubMatrix<double> src_mat (
+ mLinearityCorrectionAccu,
+ origin, rows,
+ 0, mLinearityCorrection.Cols()
+ );
+ SubMatrix<BaseFloat> tgt_mat (
+ mLinearity,
+ origin, rows,
+ 0, mLinearityCorrection.Cols()
+ );
+
+
+ //update weights
+ AddScaled(tgt_mat, src_mat, -mLearningRate);
+
+  //perform L2 regularization (weight decay): shrink the weights
+  //by the factor (1 - mLearningRate*mWeightcost*mBunchsize)
+ BaseFloat L2_decay = -mLearningRate * mWeightcost * mBunchsize;
+ if(L2_decay != 0.0) {
+ tgt_mat.AddScaled(L2_decay, tgt_mat);
+ }
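+  //(note: the decay acts on the freshly updated weights, because the
+  // gradient step above has already modified tgt_mat)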
+
+  //the first thread always updates the bias
+ if(thr == 0) {
+ //std::cout << "[" << thr << "BP]" << std::flush;
+ AddScaled(mBias, mBiasCorrectionAccu, -mLearningRate);
+ }
+
+  //reset this thread's window of the accumulator
+ src_mat.Zero();
+ if(thr == 0) {
+ mBiasCorrectionAccu.Zero();
+ }
+
+}
+
+} //namespace