summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuUpdatableBias.h
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuUpdatableBias.h
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/CuTNetLib/cuUpdatableBias.h')
-rw-r--r--src/CuTNetLib/cuUpdatableBias.h109
1 files changed, 109 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuUpdatableBias.h b/src/CuTNetLib/cuUpdatableBias.h
new file mode 100644
index 0000000..df8066a
--- /dev/null
+++ b/src/CuTNetLib/cuUpdatableBias.h
@@ -0,0 +1,109 @@
+#ifndef _CUUPDATABLE_BIAS_H_
+#define _CUUPDATABLE_BIAS_H_
+
+
+#include "cuComponent.h"
+#include "cumatrix.h"
+
+
+#include "Matrix.h"
+#include "Vector.h"
+
+
+namespace TNet {
+ /**
+ * \brief CuUpdatableBias summation function
+ *
+ * \ingroup CuNNUpdatable
+ * Implements forward pass: \f[ Y_i=X_i +{\beta}_i \f]
+ * Error propagation: \f[ E_i = e_i \f]
+ *
+ * Weight adjust:
+ * for bias: \f[ {\Beta}_i = {\beta}_i - \alpha(1-\mu)e_i - \mu \Delta \f]
+ * where
+ * - D for weight decay => penalizing large weight
+ * - \f$ \alpha \f$ for learning rate
+ * - \f$ \mu \f$ for momentum => avoiding oscillation
+ */
+ class CuUpdatableBias : public CuUpdatableComponent
+ {
+ public:
+
+ CuUpdatableBias(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+ ~CuUpdatableBias();
+
+ ComponentType GetType() const;
+ const char* GetName() const;
+
+ const CuMatrix<float>& GetErrorOutput();
+
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+
+ void Update();
+
+ void ReadFromStream(std::istream& rIn);
+ void WriteToStream(std::ostream& rOut);
+
+ void Backpropagate();
+
+ protected:
+ CuVector<BaseFloat> mBias; ///< Vector with biases
+
+ CuVector<BaseFloat> mBiasCorrection; ///< Vector for bias updates
+
+ };
+
+
+
+
+ ////////////////////////////////////////////////////////////////////////////
+ // INLINE FUNCTIONS
+ // CuUpdatableBias::
+ inline
+ CuUpdatableBias::
+ CuUpdatableBias(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+ : CuUpdatableComponent(nInputs, nOutputs, pPred),
+ mBias(nOutputs), mBiasCorrection(nOutputs)
+ {
+ mBiasCorrection.SetConst(0.0);
+ }
+
+
+ inline
+ CuUpdatableBias::
+ ~CuUpdatableBias()
+ { }
+
+ inline CuComponent::ComponentType
+ CuUpdatableBias::
+ GetType() const
+ {
+ return CuComponent::UPDATABLEBIAS;
+ }
+
+ inline const char*
+ CuUpdatableBias::
+ GetName() const
+ {
+ return "<updatablebias>";
+ }
+
+ inline void
+ CuUpdatableBias::
+ Backpropagate()
+ {
+ }
+
+ inline const CuMatrix<BaseFloat>&
+ CuUpdatableBias::
+ GetErrorOutput()
+ {
+ return GetErrorInput();
+ }
+
+} //namespace
+
+
+
+#endif