diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuUpdatableBias.h | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/CuTNetLib/cuUpdatableBias.h')
-rw-r--r-- | src/CuTNetLib/cuUpdatableBias.h | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuUpdatableBias.h b/src/CuTNetLib/cuUpdatableBias.h new file mode 100644 index 0000000..df8066a --- /dev/null +++ b/src/CuTNetLib/cuUpdatableBias.h @@ -0,0 +1,109 @@ +#ifndef _CUUPDATABLE_BIAS_H_ +#define _CUUPDATABLE_BIAS_H_ + + +#include "cuComponent.h" +#include "cumatrix.h" + + +#include "Matrix.h" +#include "Vector.h" + + +namespace TNet { + /** + * \brief CuUpdatableBias summation function + * + * \ingroup CuNNUpdatable + * Implements forward pass: \f[ Y_i=X_i +{\beta}_i \f] + * Error propagation: \f[ E_i = e_i \f] + * + * Weight adjust: + * for bias: \f[ {\Beta}_i = {\beta}_i - \alpha(1-\mu)e_i - \mu \Delta \f] + * where + * - D for weight decay => penalizing large weight + * - \f$ \alpha \f$ for learning rate + * - \f$ \mu \f$ for momentum => avoiding oscillation + */ + class CuUpdatableBias : public CuUpdatableComponent + { + public: + + CuUpdatableBias(size_t nInputs, size_t nOutputs, CuComponent *pPred); + ~CuUpdatableBias(); + + ComponentType GetType() const; + const char* GetName() const; + + const CuMatrix<float>& GetErrorOutput(); + + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + + void Update(); + + void ReadFromStream(std::istream& rIn); + void WriteToStream(std::ostream& rOut); + + void Backpropagate(); + + protected: + CuVector<BaseFloat> mBias; ///< Vector with biases + + CuVector<BaseFloat> mBiasCorrection; ///< Vector for bias updates + + }; + + + + + //////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuUpdatableBias:: + inline + CuUpdatableBias:: + CuUpdatableBias(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuUpdatableComponent(nInputs, nOutputs, pPred), + mBias(nOutputs), mBiasCorrection(nOutputs) + { + mBiasCorrection.SetConst(0.0); + } + + + inline + CuUpdatableBias:: + ~CuUpdatableBias() + { } + + inline CuComponent::ComponentType + CuUpdatableBias:: + GetType() const + { + return CuComponent::UPDATABLEBIAS; + } + + inline const char* + CuUpdatableBias:: + GetName() const + { + return "<updatablebias>"; + } + + inline void + CuUpdatableBias:: + Backpropagate() + { + } + + inline const CuMatrix<BaseFloat>& + CuUpdatableBias:: + GetErrorOutput() + { + return GetErrorInput(); + } + +} //namespace + + + +#endif |