#ifndef _CUNETWORK_COMPONENT_I_H #define _CUNETWORK_COMPONENT_I_H #include "Vector.h" #include "Matrix.h" #include "Error.h" #include "cumatrix.h" #include #include #include /// \defgroup CuNNLayer CuNN Layer types /// \ingroup CuNNComp /// \defgroup CuNNUpdatable CuNN Updatable Layer /// \ingroup CuNNLayer /// \defgroup CuNNActivation CuNN Activation Func Layer /// \ingroup CuNNLayer /// \defgroup CuNNMisc CuNN Misc Layer /// \ingroup CuNNLayer namespace TNet { /** * \brief Neural network building blocks * * Basic element of the network, * it is a box with defined inputs and outputs, * and functions to refresh outputs * * it is able to compute tranformation function (forward pass) * and jacobian function (backward pass), * which is to be implemented in descendents * * Compoments together form a doubly-linked list. * * Every input rows are a frame of data, * and the input size are the number of columns */ class CuComponent { private: static int UID_CNT; public: /// Unique id of this component int UID; /// Types of the net components typedef enum { UPDATABLE_COMPONENT = 0x0100, BIASED_LINEARITY, DISCRETE_LINEARITY, SHARED_LINEARITY, SPARSE_LINEARITY, RBM, RBM_SPARSE, RECURRENT, LINEARITY, UPDATABLEBIAS, DISCRETE, COMPOUND, ACT_FUN = 0x0200, SOFTMAX, SIGMOID, OTHER = 0x0400, EXPAND, COPY, TRANSPOSE, BLOCK_LINEARITY, WINDOW, BIAS, LOG, PIPE, LEARNSTOP, DISTRIB, COMBINE, DIVIDE, MERGE, REORDER, BLOCK_ARRAY, } ComponentType; typedef std::vector< CuMatrix* > MatrixPtrVec; typedef std::vector< const CuMatrix* > ConstMatrixPtrVec; ////////////////////////////////////////////////////////////// // Constructor & Destructor public: CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred); virtual ~CuComponent(); ////////////////////////////////////////////////////////////// // Interface specification (public) public: /// Get Type Identification of the component virtual ComponentType GetType() const = 0; /// Get Type Label of the component virtual const char* GetName() const = 0; /// Return if the component is UpdatableComponent virtual bool IsUpdatable() const { return false; } /// Get size of input vectors size_t GetNInputs() const; /// Get size of output vectors size_t GetNOutputs() const; /// Set size of input vectors size_t SetNInputs(size_t nInputs); /// Set size of output vectors size_t SetNOutputs(size_t nOutputs); /// Set the previous component void SetPrevious(CuComponent* pPred); /// Set the next component void SetNext(CuComponent* pNxt); /// Return the number of different inputs for complex component virtual int GetInSect(){return 1;} /// Return the number of different outputs for complex component virtual int GetOutSect(){return 1;} /// IO Data getters virtual const CuMatrix& GetInput(int pos=0) { if (NULL == mpInput) Error("mpInput is NULL"); return *mpInput; } virtual const CuMatrix& GetOutput(int pos=0) { return mOutput; } virtual const CuMatrix& GetErrorInput(int pos=0) { if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); return *mpErrorInput; } virtual const CuMatrix& GetErrorOutput(int pos=0) { return mErrorOutput; } /// Set input vector (bind with the preceding NetworkComponent) virtual void SetInput(const CuMatrix& rInput,int pos=0) { mpInput = &rInput; } /// Set error input vector (bind with the following NetworkComponent) virtual void SetErrorInput(const CuMatrix& rErrorInput,int pos=0) { mpErrorInput = &rErrorInput; } /// Perform forward pass propagateion Input->Output, /// wrapper for the PropagateFnc method virtual void Propagate() { //initialize output buffer mOutput.Init(GetInput().Rows(),GetNOutputs()); //do the dimensionality test if(GetNInputs() != GetInput().Cols()) { KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() << " Data dim: " << GetInput().Cols(); } //run transform PropagateF(GetInput(),mOutput); } /// Perform backward pass propagateion ErrorInput->ErrorOutput, /// wrapper for the BackpropagateFnc method virtual void Backpropagate() { //re-initialize the output buffer mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); //do the dimensionality test assert(GetErrorInput().Cols() == mNOutputs); assert(mErrorOutput.Cols() == mNInputs); assert(mErrorOutput.Rows() == GetErrorInput().Rows()); //transform BackpropagateF(GetErrorInput(),mErrorOutput); } /// Reads the component parameters from stream virtual void ReadFromStream(std::istream& rIn) { } /// Writes the components parameters to stream virtual void WriteToStream(std::ostream& rOut) { } /// Public wrapper for PropagateFnc void PropagateF(const CuMatrix& X, CuMatrix& Y); /// Public wrapper for BackpropagateFnc void BackpropagateF(const CuMatrix& X, CuMatrix& Y); /////////////////////////////////////////////////////////////// // Nonpublic member functions used to update data outputs protected: /// Forward pass transformation (to be implemented by descendents...) /// \param[in] X InputMatrix (Network input or Output from last layer) /// \param[out] Y OutputMatrix (Network output or input of the next layer) virtual void PropagateFnc(const CuMatrix& X, CuMatrix& Y) = 0; /// Backward pass transformation (to be implemented by descendents...) /// \param[in] X InputMatrix (Network Error, objective func output, or Error output from the next layer) /// \param[out] Y OutputMatrix (Error input of the last layer) virtual void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) = 0; /////////////////////////////////////////////////////////////// // data members protected: size_t mNInputs; ///< Size of input vectors size_t mNOutputs; ///< Size of output vectors const CuMatrix* mpInput; ///< inputs are NOT OWNED by component const CuMatrix* mpErrorInput;///< inputs are NOT OWNED by component CuMatrix mOutput; ///< outputs are OWNED by component CuMatrix mErrorOutput; ///< outputs are OWNED by component CuComponent* preComp;///< The preceding component in the Network CuComponent* nxtComp;///< The following component in the Network }; /** * \brief Class UpdatableComponent is a box which has some * parameters adjustable by learning * * you can set the learning rate, lock the params, * and learn from each data observation */ class CuUpdatableComponent : public CuComponent { ////////////////////////////////////////////////////////////// // Constructor & Destructor public: CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred); virtual ~CuUpdatableComponent(); ////////////////////////////////////////////////////////////// // Interface specification (public) public: /// Return if CuUpdatableComponent is updatable? virtual bool IsUpdatable() const { return true; } /// get gradient and update the parameters in one step virtual void Update() = 0; /// Sets the learning rate of gradient descent void LearnRate(BaseFloat rate); /// Gets the learning rate of gradient descent BaseFloat LearnRate(); /// Sets the momentum void Momentum(BaseFloat mmt); BaseFloat Momentum(); /// Set the weight decay rate to penalize large weights void Weightcost(BaseFloat cost); BaseFloat Weightcost(); /// Set whether gradient is divided by frames void GradDivFrm(bool div); bool GradDivFrm(); protected: BaseFloat mLearningRate; BaseFloat mMomentum; BaseFloat mWeightcost; bool mGradDivFrm; }; ////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // CuComponent:: inline CuComponent:: CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred) : mNInputs(nInputs), mNOutputs(nOutputs), mpInput(NULL), mpErrorInput(NULL), mOutput(), mErrorOutput(),preComp(pPred) { UID=++(CuComponent::UID_CNT); /* DOUBLE LINK the Components */ if (pPred != NULL) { SetPrevious(pPred); pPred->SetNext(this); } } inline void CuComponent:: SetPrevious(CuComponent* pPred) { preComp=pPred; /* DOUBLE LINK the Components */ if (pPred != NULL) { SetInput(pPred->GetOutput()); } } inline void CuComponent:: SetNext(CuComponent* pNxt) { nxtComp=pNxt; if (pNxt != NULL) { SetErrorInput(pNxt->GetErrorOutput()); } } inline CuComponent:: ~CuComponent() { ; } inline size_t CuComponent:: GetNInputs() const { return mNInputs; } inline size_t CuComponent:: GetNOutputs() const { return mNOutputs; } inline size_t CuComponent:: SetNInputs(size_t nInputs) { mNInputs=nInputs; } inline size_t CuComponent:: SetNOutputs(size_t nOutputs) { mNOutputs=nOutputs; } inline void CuComponent:: PropagateF(const CuMatrix& X, CuMatrix& Y) { PropagateFnc(X,Y); } inline void CuComponent:: BackpropagateF(const CuMatrix& X, CuMatrix& Y) { BackpropagateFnc(X,Y); } ////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // UpdatableComponent:: inline CuUpdatableComponent:: CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred) : CuComponent(nInputs, nOutputs, pPred), mLearningRate(0.0), mMomentum(0), mWeightcost(0), mGradDivFrm(true) { ; } inline CuUpdatableComponent:: ~CuUpdatableComponent() { ; } inline void CuUpdatableComponent:: LearnRate(BaseFloat rate) { mLearningRate = rate; } inline BaseFloat CuUpdatableComponent:: LearnRate() { return mLearningRate; } inline void CuUpdatableComponent:: Momentum(BaseFloat mmt) { mMomentum = mmt; } inline BaseFloat CuUpdatableComponent:: Momentum() { return mMomentum; } inline void CuUpdatableComponent:: Weightcost(BaseFloat cost) { mWeightcost = cost; } inline BaseFloat CuUpdatableComponent:: Weightcost() { return mWeightcost; } inline void CuUpdatableComponent:: GradDivFrm(bool div) { mGradDivFrm = div; } inline bool CuUpdatableComponent:: GradDivFrm() { return mGradDivFrm; } } // namespace TNet #endif