From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- src/CuTNetLib/cuLinearity.h | 94 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 src/CuTNetLib/cuLinearity.h (limited to 'src/CuTNetLib/cuLinearity.h') diff --git a/src/CuTNetLib/cuLinearity.h b/src/CuTNetLib/cuLinearity.h new file mode 100644 index 0000000..050591d --- /dev/null +++ b/src/CuTNetLib/cuLinearity.h @@ -0,0 +1,94 @@ +#ifndef _CULINEARITY_H_ +#define _CULINEARITY_H_ + + +#include "cuComponent.h" +#include "cumatrix.h" + + +#include "Matrix.h" +#include "Vector.h" + + +namespace TNet { + /** + * \brief CuLinearity summation function + * + * \ingroup CuNNUpdatable + * Implements forward pass: \f[ Y_j=\Sigma_{i=0}^{i=N-1}w_ij X_i +{\beta}_j \f] + * Error propagation: \f[ E_i = \Sigma_{i=0}^{i=N-1} w_ij e_j \f] + * + * Weight adjustion: \f[ W_{ij} = (1-D)(w_{ij} - \alpha(1-\mu)x_i e_j - \mu \Delta) \f] + * where + * - D for weight decay => penalizing large weight + * - \f$ \alpha \f$ for learning rate + * - \f$ \mu \f$ for momentum => avoiding oscillation + */ + class CuLinearity : public CuUpdatableComponent + { + public: + + CuLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred); + ~CuLinearity(); + + ComponentType GetType() const; + const char* GetName() const; + + void PropagateFnc(const CuMatrix& X, CuMatrix& Y); + void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); + + void Update(); + + void ReadFromStream(std::istream& rIn); + void WriteToStream(std::ostream& rOut); + + protected: + CuMatrix mLinearity; ///< Matrix with neuron weights + + CuMatrix mLinearityCorrection; ///< Matrix for linearity updates + + }; + + + + + //////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuLinearity:: + inline + CuLinearity:: + CuLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuUpdatableComponent(nInputs, nOutputs, pPred), + mLinearity(nInputs,nOutputs), + mLinearityCorrection(nInputs,nOutputs) + { + mLinearityCorrection.SetConst(0.0); + } + + + inline + CuLinearity:: + ~CuLinearity() + { } + + inline CuComponent::ComponentType + CuLinearity:: + GetType() const + { + return CuComponent::LINEARITY; + } + + inline const char* + CuLinearity:: + GetName() const + { + return ""; + } + + + +} //namespace + + + +#endif -- cgit v1.2.3-70-g09d2