#ifndef _BIASED_LINEARITY_H_ #define _BIASED_LINEARITY_H_ #include "Component.h" #include "Matrix.h" #include "Vector.h" namespace TNet { class BiasedLinearity : public UpdatableComponent { public: BiasedLinearity(size_t nInputs, size_t nOutputs, Component *pPred); ~BiasedLinearity() { } ComponentType GetType() const { return BIASED_LINEARITY; } const char* GetName() const { return ""; } Component* Clone() const; void PropagateFnc(const Matrix& X, Matrix& Y); void BackpropagateFnc(const Matrix& X, Matrix& Y); void ReadFromStream(std::istream& rIn); void WriteToStream(std::ostream& rOut); /// calculate gradient void Gradient(); /// accumulate gradient from other components void AccuGradient(const UpdatableComponent& src, int thr, int thrN); /// update weights, reset the accumulator void Update(int thr, int thrN); protected: Matrix mLinearity; ///< Matrix with neuron weights Vector mBias; ///< Vector with biases const Matrix* mpLinearity; const Vector* mpBias; Matrix mLinearityCorrection; ///< Matrix for linearity updates Vector mBiasCorrection; ///< Vector for bias updates Matrix mLinearityCorrectionAccu; ///< Matrix for summing linearity updates Vector mBiasCorrectionAccu; ///< Vector for summing bias updates }; //////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // BiasedLinearity:: inline BiasedLinearity:: BiasedLinearity(size_t nInputs, size_t nOutputs, Component *pPred) : UpdatableComponent(nInputs, nOutputs, pPred), mLinearity(), mBias(), //cloned instaces don't need this mpLinearity(&mLinearity), mpBias(&mBias), mLinearityCorrection(nInputs,nOutputs), mBiasCorrection(nOutputs), mLinearityCorrectionAccu(), mBiasCorrectionAccu() //cloned instances don't need this { } inline Component* BiasedLinearity:: Clone() const { BiasedLinearity* ptr = new BiasedLinearity(GetNInputs(), GetNOutputs(), NULL); ptr->mpLinearity = mpLinearity; //copy pointer from currently active weights ptr->mpBias = mpBias; //... ptr->mLearningRate = mLearningRate; return ptr; } } //namespace #endif