From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- .../.svn/text-base/SharedLinearity.h.svn-base | 103 +++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 src/TNetLib/.svn/text-base/SharedLinearity.h.svn-base (limited to 'src/TNetLib/.svn/text-base/SharedLinearity.h.svn-base') diff --git a/src/TNetLib/.svn/text-base/SharedLinearity.h.svn-base b/src/TNetLib/.svn/text-base/SharedLinearity.h.svn-base new file mode 100644 index 0000000..83feeee --- /dev/null +++ b/src/TNetLib/.svn/text-base/SharedLinearity.h.svn-base @@ -0,0 +1,103 @@ +#ifndef _CUSHARED_LINEARITY_H_ +#define _CUSHARED_LINEARITY_H_ + + +#include "Component.h" + +#include "Matrix.h" +#include "Vector.h" + + +namespace TNet { + +class SharedLinearity : public UpdatableComponent +{ + public: + SharedLinearity(size_t nInputs, size_t nOutputs, Component *pPred); + ~SharedLinearity(); + + ComponentType GetType() const + { return SHARED_LINEARITY; } + + const char* GetName() const + { return ""; } + + Component* Clone() const; + + void PropagateFnc(const Matrix& X, Matrix& Y); + void BackpropagateFnc(const Matrix& X, Matrix& Y); + + void ReadFromStream(std::istream& rIn); + void WriteToStream(std::ostream& rOut); + + /// calculate gradient + void Gradient(); + /// accumulate gradient from other components + void AccuGradient(const UpdatableComponent& src, int thr, int thrN); + /// update weights, reset the accumulator + void Update(int thr, int thrN); + +protected: + Matrix mLinearity; ///< Matrix with neuron weights + Vector mBias; ///< Vector with biases + + Matrix* mpLinearity; + Vector* mpBias; + + Matrix mLinearityCorrection; ///< Matrix for linearity updates + Vector mBiasCorrection; ///< Vector for bias updates + + Matrix mLinearityCorrectionAccu; ///< Accumulator for linearity updates + Vector mBiasCorrectionAccu; ///< Accumulator for bias updates + + int mNInstances; +}; + + + + +//////////////////////////////////////////////////////////////////////////// +// INLINE FUNCTIONS +// SharedLinearity:: +inline +SharedLinearity:: +SharedLinearity(size_t nInputs, size_t nOutputs, Component *pPred) + : UpdatableComponent(nInputs, nOutputs, pPred), + mpLinearity(&mLinearity), mpBias(&mBias), + mNInstances(0) +{ } + + +inline +SharedLinearity:: +~SharedLinearity() +{ } + + +inline +Component* +SharedLinearity:: +Clone() const +{ + SharedLinearity* ptr = new SharedLinearity(GetNInputs(),GetNOutputs(),NULL); + ptr->mpLinearity = mpLinearity; + ptr->mpBias = mpBias; + + ptr->mLinearityCorrection.Init(mpLinearity->Rows(),mpLinearity->Cols()); + ptr->mBiasCorrection.Init(mpBias->Dim()); + + ptr->mNInstances = mNInstances; + + ptr->mLearningRate = mLearningRate; + + + return ptr; +} + + + +} //namespace + + + +#endif -- cgit v1.2.3-70-g09d2