#ifndef _CUBIASED_LINEARITY_H_ #define _CUBIASED_LINEARITY_H_ #include "cuComponent.h" #include "cumatrix.h" #include "Matrix.h" #include "Vector.h" namespace TNet { class CuBiasedLinearity : public CuUpdatableComponent { public: CuBiasedLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred); ~CuBiasedLinearity(); ComponentType GetType() const; const char* GetName() const; void PropagateFnc(const CuMatrix& X, CuMatrix& Y); void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); void Update(); void ReadFromStream(std::istream& rIn); void WriteToStream(std::ostream& rOut); protected: CuMatrix mLinearity; ///< Matrix with neuron weights CuVector mBias; ///< Vector with biases CuMatrix mLinearityCorrection; ///< Matrix for linearity updates CuVector mBiasCorrection; ///< Vector for bias updates }; //////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // CuBiasedLinearity:: inline CuBiasedLinearity:: CuBiasedLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred) : CuUpdatableComponent(nInputs, nOutputs, pPred), mLinearity(nInputs,nOutputs), mBias(nOutputs), mLinearityCorrection(nInputs,nOutputs), mBiasCorrection(nOutputs) { mLinearityCorrection.SetConst(0.0); mBiasCorrection.SetConst(0.0); } inline CuBiasedLinearity:: ~CuBiasedLinearity() { } inline CuComponent::ComponentType CuBiasedLinearity:: GetType() const { return CuComponent::BIASED_LINEARITY; } inline const char* CuBiasedLinearity:: GetName() const { return ""; } } //namespace #endif