#ifndef _NETWORK_COMPONENT_I_H #define _NETWORK_COMPONENT_I_H #include "Vector.h" #include "Matrix.h" #include #include namespace TNet { /** * Basic element of the network, * it is a box with defined inputs and outputs, * and functions to refresh outputs * * it is able to compute tranformation function (forward pass) * and jacobian function (backward pass), * which is to be implemented in descendents */ class Component { public: /// Types of the net components typedef enum { UPDATABLE_COMPONENT = 0x0100, BIASED_LINEARITY, SHARED_LINEARITY, ACT_FUN = 0x0200, SOFTMAX, SIGMOID, BLOCK_SOFTMAX, OTHER = 0x0400, EXPAND, COPY, TRANSPOSE, BLOCK_LINEARITY, WINDOW, BIAS, LOG, BLOCK_ARRAY, } ComponentType; ////////////////////////////////////////////////////////////// // Constructor & Destructor public: Component(size_t nInputs, size_t nOutputs, Component *pPred); virtual ~Component(); ////////////////////////////////////////////////////////////// // Interface specification (public) public: /// Get Type Identification of the component virtual ComponentType GetType() const = 0; /// Get Type Label of the component virtual const char* GetName() const = 0; /// virtual bool IsUpdatable() const { return false; } /// Clone the component virtual Component* Clone() const = 0; /// Get size of input vectors size_t GetNInputs() const; /// Get size of output vectors size_t GetNOutputs() const; /// IO Data getters const Matrix& GetInput() const; const Matrix& GetOutput() const; const Matrix& GetErrorInput() const; const Matrix& GetErrorOutput() const; /// Set input vector (bind with the preceding NetworkComponent) void SetInput(const Matrix& rInput); /// Set error input vector (bind with the following NetworkComponent) void SetErrorInput(const Matrix& rErrorInput); /// Perform forward pass propagateion Input->Output void Propagate(); /// Perform backward pass propagateion ErrorInput->ErrorOutput void Backpropagate(); /// Reads the component parameters from stream virtual void ReadFromStream(std::istream& rIn) { } /// Writes the components parameters to stream virtual void WriteToStream(std::ostream& rOut) { } /////////////////////////////////////////////////////////////// // Nonpublic member functions used to update data outputs protected: /// Forward pass transformation (to be implemented by descendents...) virtual void PropagateFnc(const Matrix& X, Matrix& Y) = 0; /// Backward pass transformation (to be implemented by descendents...) virtual void BackpropagateFnc(const Matrix& X, Matrix& Y) = 0; /////////////////////////////////////////////////////////////// // data members protected: size_t mNInputs; ///< Size of input vectors size_t mNOutputs; ///< Size of output vectors const Matrix* mpInput; ///< inputs are NOT OWNED by component const Matrix* mpErrorInput;///< inputs are NOT OWNED by component Matrix mOutput; ///< outputs are OWNED by component Matrix mErrorOutput; ///< outputs are OWNED by component }; /** * Class UpdatableComponent is a box which has some * parameters adjustable by learning * * you can set the learning rate, lock the params, * and learn from each data observation */ class UpdatableComponent : public Component { ////////////////////////////////////////////////////////////// // Constructor & Destructor public: UpdatableComponent(size_t nInputs, size_t nOutputs, Component *pPred); virtual ~UpdatableComponent(); ////////////////////////////////////////////////////////////// // Interface specification (public) public: /// virtual bool IsUpdatable() const { return true; } /// calculate gradient virtual void Gradient() = 0; /// accumulate gradient from other components virtual void AccuGradient(const UpdatableComponent& src, int thr, int thrN) = 0; /// update weights, reset the accumulator virtual void Update(int thr, int thrN) = 0; /// Sets the learning rate of gradient descent void LearnRate(BaseFloat rate); /// Gets the learning rate of gradient descent BaseFloat LearnRate() const; void Momentum(BaseFloat mmt); BaseFloat Momentum() const ; void Weightcost(BaseFloat cost); BaseFloat Weightcost() const; void Bunchsize(size_t size); size_t Bunchsize() const; protected: BaseFloat mLearningRate; BaseFloat mMomentum; BaseFloat mWeightcost; size_t mBunchsize; }; ////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // Component:: inline Component:: Component(size_t nInputs, size_t nOutputs, Component *pPred) : mNInputs(nInputs), mNOutputs(nOutputs), mpInput(NULL), mpErrorInput(NULL), mOutput(), mErrorOutput() { /* DOUBLE LINK the Components */ if (pPred != NULL) { SetInput(pPred->GetOutput()); pPred->SetErrorInput(GetErrorOutput()); } } inline Component:: ~Component() { ; } inline void Component:: Propagate() { //initialize output buffer if(mOutput.Rows() != GetInput().Rows() || mOutput.Cols() != GetNOutputs()) { mOutput.Init(GetInput().Rows(),GetNOutputs()); } //do the dimensionality test if(GetNInputs() != GetInput().Cols()) { KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() << " Data dim: " << GetInput().Cols(); } //run transform PropagateFnc(GetInput(),mOutput); } inline void Component:: Backpropagate() { //re-initialize the output buffer if(mErrorOutput.Rows() != GetErrorInput().Rows() || mErrorOutput.Cols() != GetNInputs()) { mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); } //do the dimensionality test assert(GetErrorInput().Cols() == mNOutputs); assert(mErrorOutput.Cols() == mNInputs); assert(mErrorOutput.Rows() == GetErrorInput().Rows()); //transform BackpropagateFnc(GetErrorInput(),mErrorOutput); } inline void Component:: SetInput(const Matrix& rInput) { mpInput = &rInput; } inline void Component:: SetErrorInput(const Matrix& rErrorInput) { mpErrorInput = &rErrorInput; } inline const Matrix& Component:: GetInput() const { if (NULL == mpInput) Error("mpInput is NULL"); return *mpInput; } inline const Matrix& Component:: GetOutput() const { return mOutput; } inline const Matrix& Component:: GetErrorInput() const { if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); return *mpErrorInput; } inline const Matrix& Component:: GetErrorOutput() const { return mErrorOutput; } inline size_t Component:: GetNInputs() const { return mNInputs; } inline size_t Component:: GetNOutputs() const { return mNOutputs; } ////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // UpdatableComponent:: inline UpdatableComponent:: UpdatableComponent(size_t nInputs, size_t nOutputs, Component *pPred) : Component(nInputs, nOutputs, pPred), mLearningRate(0.0), mMomentum(0.0), mWeightcost(0.0), mBunchsize(0) { ; } inline UpdatableComponent:: ~UpdatableComponent() { ; } inline void UpdatableComponent:: LearnRate(BaseFloat rate) { mLearningRate = rate; } inline BaseFloat UpdatableComponent:: LearnRate() const { return mLearningRate; } inline void UpdatableComponent:: Momentum(BaseFloat mmt) { mMomentum = mmt; } inline BaseFloat UpdatableComponent:: Momentum() const { return mMomentum; } inline void UpdatableComponent:: Weightcost(BaseFloat cost) { mWeightcost = cost; } inline BaseFloat UpdatableComponent:: Weightcost() const { return mWeightcost; } inline void UpdatableComponent:: Bunchsize(size_t size) { mBunchsize = size; } inline size_t UpdatableComponent:: Bunchsize() const { return mBunchsize; } } // namespace TNet #endif