diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuComponent.h | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/CuTNetLib/cuComponent.h')
-rw-r--r-- | src/CuTNetLib/cuComponent.h | 505 |
1 files changed, 505 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h new file mode 100644 index 0000000..6cc8462 --- /dev/null +++ b/src/CuTNetLib/cuComponent.h @@ -0,0 +1,505 @@ +#ifndef _CUNETWORK_COMPONENT_I_H +#define _CUNETWORK_COMPONENT_I_H + + +#include "Vector.h" +#include "Matrix.h" +#include "Error.h" + +#include "cumatrix.h" + +#include <iostream> +#include <stdexcept> +#include <vector> + +/// \defgroup CuNNLayer CuNN Layer types +/// \ingroup CuNNComp + +/// \defgroup CuNNUpdatable CuNN Updatable Layer +/// \ingroup CuNNLayer + +/// \defgroup CuNNActivation CuNN Activation Func Layer +/// \ingroup CuNNLayer + +/// \defgroup CuNNMisc CuNN Misc Layer +/// \ingroup CuNNLayer + +namespace TNet { + + + /** + * \brief Neural network building blocks + * + * Basic element of the network, + * it is a box with defined inputs and outputs, + * and functions to refresh outputs + * + * it is able to compute tranformation function (forward pass) + * and jacobian function (backward pass), + * which is to be implemented in descendents + * + * Compoments together form a doubly-linked list. + * + * Every input rows are a frame of data, + * and the input size are the number of columns + */ + class CuComponent + { + public: + /// Types of the net components + typedef enum { + UPDATABLE_COMPONENT = 0x0100, + BIASED_LINEARITY, + DISCRETE_LINEARITY, + SHARED_LINEARITY, + SPARSE_LINEARITY, + RBM, + RBM_SPARSE, + RECURRENT, + LINEARITY, + UPDATABLEBIAS, + DISCRETE, + COMPOUND, + + ACT_FUN = 0x0200, + SOFTMAX, + SIGMOID, + + OTHER = 0x0400, + EXPAND, + COPY, + TRANSPOSE, + BLOCK_LINEARITY, + WINDOW, + BIAS, + LOG, + PIPE, + LEARNSTOP, + DISTRIB, + COMBINE, + DIVIDE, + MERGE, + REORDER, + + BLOCK_ARRAY, + } ComponentType; + + typedef std::vector< CuMatrix<BaseFloat>* > MatrixPtrVec; + + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred); + virtual ~CuComponent(); + + ////////////////////////////////////////////////////////////// + // Interface specification (public) + public: + /// Get Type Identification of the component + virtual ComponentType GetType() const = 0; + /// Get Type Label of the component + virtual const char* GetName() const = 0; + /// Return if the component is UpdatableComponent + virtual bool IsUpdatable() const + { return false; } + + /// Get size of input vectors + size_t GetNInputs() const; + /// Get size of output vectors + size_t GetNOutputs() const; + + /// Set size of input vectors + size_t SetNInputs(size_t nInputs); + /// Set size of output vectors + size_t SetNOutputs(size_t nOutputs); + /// Set the previous component + void SetPrevious(CuComponent* pPred); + /// Set the next component + void SetNext(CuComponent* pNxt); + + /// Return the number of different inputs for complex component + int GetInSect(); + /// Return the number of different outputs for complex component + int GetOutSect(); + + /// IO Data getters + CuMatrix<BaseFloat>& GetInput(int pos=0); + CuMatrix<BaseFloat>& GetOutput(int pos=0); + CuMatrix<BaseFloat>& GetErrorInput(int pos=0); + CuMatrix<BaseFloat>& GetErrorOutput(int pos=0); + + /// Set input vector (bind with the preceding NetworkComponent) + void SetInput(CuMatrix<BaseFloat>& rInput,int pos=0); + /// Set error input vector (bind with the following NetworkComponent) + void SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos=0); + + /// Perform forward pass propagateion Input->Output, + /// wrapper for the PropagateFnc method + void Propagate(); + /// Perform backward pass propagateion ErrorInput->ErrorOutput, + /// wrapper for the BackpropagateFnc method + void Backpropagate(); + + /// Reads the component parameters from stream + virtual void ReadFromStream(std::istream& rIn) { } + /// Writes the components parameters to stream + virtual void WriteToStream(std::ostream& rOut) { } + + /// Public wrapper for PropagateFnc + void PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + /// Public wrapper for BackpropagateFnc + void BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + + + /////////////////////////////////////////////////////////////// + // Nonpublic member functions used to update data outputs + protected: + /// Forward pass transformation (to be implemented by descendents...) + /// \param[in] X InputMatrix (Network input or Output from last layer) + /// \param[out] Y OutputMatrix (Network output or input of the next layer) + virtual void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) = 0; + /// Backward pass transformation (to be implemented by descendents...) + /// \param[in] X InputMatrix (Network Error, objective func output, or Error output from the next layer) + /// \param[out] Y OutputMatrix (Error input of the last layer) + virtual void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) = 0; + + + /////////////////////////////////////////////////////////////// + // data members + protected: + + size_t mNInputs; ///< Size of input vectors + size_t mNOutputs; ///< Size of output vectors + + CuMatrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component + CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component + + CuMatrix<BaseFloat> mOutput; ///< outputs are OWNED by component + CuMatrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component + + CuComponent* preComp;///< The preceding component in the Network + CuComponent* nxtComp;///< The following component in the Network + }; + + + /** + * \brief Class UpdatableComponent is a box which has some + * parameters adjustable by learning + * + * you can set the learning rate, lock the params, + * and learn from each data observation + */ + class CuUpdatableComponent : public CuComponent + { + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred); + virtual ~CuUpdatableComponent(); + + + ////////////////////////////////////////////////////////////// + // Interface specification (public) + public: + /// Return if CuUpdatableComponent is updatable? + virtual bool IsUpdatable() const + { return true; } + + /// get gradient and update the parameters in one step + virtual void Update() = 0; + + /// Sets the learning rate of gradient descent + void LearnRate(BaseFloat rate); + /// Gets the learning rate of gradient descent + BaseFloat LearnRate(); + + /// Sets the momentum + void Momentum(BaseFloat mmt); + BaseFloat Momentum(); + + /// Set the weight decay rate to penalize large weights + void Weightcost(BaseFloat cost); + BaseFloat Weightcost(); + + /// Set whether gradient is divided by frames + void GradDivFrm(bool div); + bool GradDivFrm(); + + protected: + BaseFloat mLearningRate; + BaseFloat mMomentum; + BaseFloat mWeightcost; + bool mGradDivFrm; + + }; + + + + + ////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuComponent:: + inline + CuComponent:: + CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : mNInputs(nInputs), mNOutputs(nOutputs), + mpInput(NULL), mpErrorInput(NULL), + mOutput(), mErrorOutput(),preComp(pPred) + { + /* DOUBLE LINK the Components */ + if (pPred != NULL) { + SetPrevious(pPred); + pPred->SetNext(this); + } + } + + inline void + CuComponent:: + SetPrevious(CuComponent* pPred) + { + preComp=pPred; + /* DOUBLE LINK the Components */ + if (pPred != NULL) { + SetInput(pPred->GetOutput()); + } + } + + inline void + CuComponent:: + SetNext(CuComponent* pNxt) + { + nxtComp=pNxt; + if (pNxt != NULL) { + SetErrorInput(pNxt->GetErrorOutput()); + } + } + + inline + CuComponent:: + ~CuComponent() + { + ; + } + + inline void + CuComponent:: + Propagate() + { + //initialize output buffer + mOutput.Init(GetInput().Rows(),GetNOutputs()); + //do the dimensionality test + if(GetNInputs() != GetInput().Cols()) { + KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() + << " Data dim: " << GetInput().Cols(); + } + //run transform + PropagateF(GetInput(),mOutput); + } + + + inline void + CuComponent:: + Backpropagate() + { + //re-initialize the output buffer + mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); + + //do the dimensionality test + assert(GetErrorInput().Cols() == mNOutputs); + assert(mErrorOutput.Cols() == mNInputs); + assert(mErrorOutput.Rows() == GetErrorInput().Rows()); + + //transform + BackpropagateF(GetErrorInput(),mErrorOutput); + } + + + inline void + CuComponent:: + SetInput(CuMatrix<BaseFloat>& rInput,int pos) + { + mpInput = &rInput; + } + + + inline void + CuComponent:: + SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos) + { + mpErrorInput = &rErrorInput; + } + + inline CuMatrix<BaseFloat>& + CuComponent:: + GetInput(int pos) + { + if (NULL == mpInput) Error("mpInput is NULL"); + return *mpInput; + } + + inline CuMatrix<BaseFloat>& + CuComponent:: + GetOutput(int pos) + { + return mOutput; + } + + inline CuMatrix<BaseFloat>& + CuComponent:: + GetErrorInput(int pos) + { + if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); + return *mpErrorInput; + } + + inline CuMatrix<BaseFloat>& + CuComponent:: + GetErrorOutput(int pos) + { + return mErrorOutput; + } + + inline size_t + CuComponent:: + GetNInputs() const + { + return mNInputs; + } + + inline size_t + CuComponent:: + GetNOutputs() const + { + return mNOutputs; + } + + inline int + CuComponent:: + GetInSect() + { + return 1; + } + + inline int + CuComponent:: + GetOutSect() + { + return 1; + } + + inline size_t + CuComponent:: + SetNInputs(size_t nInputs) + { + mNInputs=nInputs; + } + + inline size_t + CuComponent:: + SetNOutputs(size_t nOutputs) + { + mNOutputs=nOutputs; + } + + inline void + CuComponent:: + PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + PropagateFnc(X,Y); + } + inline void + CuComponent:: + BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + BackpropagateFnc(X,Y); + } + + + ////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // UpdatableComponent:: + + inline + CuUpdatableComponent:: + CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuComponent(nInputs, nOutputs, pPred), + mLearningRate(0.0), mMomentum(0), mWeightcost(0), mGradDivFrm(true) + { + ; + } + + + inline + CuUpdatableComponent:: + ~CuUpdatableComponent() + { + ; + } + + + inline void + CuUpdatableComponent:: + LearnRate(BaseFloat rate) + { + mLearningRate = rate; + } + + + inline BaseFloat + CuUpdatableComponent:: + LearnRate() + { + return mLearningRate; + } + + + inline void + CuUpdatableComponent:: + Momentum(BaseFloat mmt) + { + mMomentum = mmt; + } + + + inline BaseFloat + CuUpdatableComponent:: + Momentum() + { + return mMomentum; + } + + + inline void + CuUpdatableComponent:: + Weightcost(BaseFloat cost) + { + mWeightcost = cost; + } + + + inline BaseFloat + CuUpdatableComponent:: + Weightcost() + { + return mWeightcost; + } + + + inline void + CuUpdatableComponent:: + GradDivFrm(bool div) + { + mGradDivFrm = div; + } + + inline bool + CuUpdatableComponent:: + GradDivFrm() + { + return mGradDivFrm; + } + +} // namespace TNet + + +#endif |