diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNetLib/Component.h | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/TNetLib/Component.h')
-rw-r--r-- | src/TNetLib/Component.h | 387 |
1 files changed, 387 insertions, 0 deletions
diff --git a/src/TNetLib/Component.h b/src/TNetLib/Component.h new file mode 100644 index 0000000..762451e --- /dev/null +++ b/src/TNetLib/Component.h @@ -0,0 +1,387 @@ +#ifndef _NETWORK_COMPONENT_I_H +#define _NETWORK_COMPONENT_I_H + + +#include "Vector.h" +#include "Matrix.h" + +#include <iostream> +#include <stdexcept> + + +namespace TNet { + + + /** + * Basic element of the network, + * it is a box with defined inputs and outputs, + * and functions to refresh outputs + * + * it is able to compute tranformation function (forward pass) + * and jacobian function (backward pass), + * which is to be implemented in descendents + */ + class Component + { + public: + /// Types of the net components + typedef enum { + UPDATABLE_COMPONENT = 0x0100, + BIASED_LINEARITY, + SHARED_LINEARITY, + + ACT_FUN = 0x0200, + SOFTMAX, + SIGMOID, + BLOCK_SOFTMAX, + + OTHER = 0x0400, + EXPAND, + COPY, + TRANSPOSE, + BLOCK_LINEARITY, + WINDOW, + BIAS, + LOG, + + BLOCK_ARRAY, + } ComponentType; + + + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + Component(size_t nInputs, size_t nOutputs, Component *pPred); + virtual ~Component(); + + ////////////////////////////////////////////////////////////// + // Interface specification (public) + public: + /// Get Type Identification of the component + virtual ComponentType GetType() const = 0; + /// Get Type Label of the component + virtual const char* GetName() const = 0; + /// + virtual bool IsUpdatable() const + { return false; } + /// Clone the component + virtual Component* Clone() const = 0; + + /// Get size of input vectors + size_t GetNInputs() const; + /// Get size of output vectors + size_t GetNOutputs() const; + + /// IO Data getters + const Matrix<BaseFloat>& GetInput() const; + const Matrix<BaseFloat>& GetOutput() const; + const Matrix<BaseFloat>& GetErrorInput() const; + const Matrix<BaseFloat>& GetErrorOutput() const; + + /// Set input vector (bind with the preceding NetworkComponent) + void SetInput(const Matrix<BaseFloat>& rInput); + /// Set error input vector (bind with the following NetworkComponent) + void SetErrorInput(const Matrix<BaseFloat>& rErrorInput); + + /// Perform forward pass propagateion Input->Output + void Propagate(); + /// Perform backward pass propagateion ErrorInput->ErrorOutput + void Backpropagate(); + + /// Reads the component parameters from stream + virtual void ReadFromStream(std::istream& rIn) { } + /// Writes the components parameters to stream + virtual void WriteToStream(std::ostream& rOut) { } + + + /////////////////////////////////////////////////////////////// + // Nonpublic member functions used to update data outputs + protected: + /// Forward pass transformation (to be implemented by descendents...) + virtual void PropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y) = 0; + /// Backward pass transformation (to be implemented by descendents...) + virtual void BackpropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y) = 0; + + + /////////////////////////////////////////////////////////////// + // data members + protected: + + size_t mNInputs; ///< Size of input vectors + size_t mNOutputs; ///< Size of output vectors + + const Matrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component + const Matrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component + + Matrix<BaseFloat> mOutput; ///< outputs are OWNED by component + Matrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component + + }; + + + /** + * Class UpdatableComponent is a box which has some + * parameters adjustable by learning + * + * you can set the learning rate, lock the params, + * and learn from each data observation + */ + class UpdatableComponent : public Component + { + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + UpdatableComponent(size_t nInputs, size_t nOutputs, Component *pPred); + virtual ~UpdatableComponent(); + + + ////////////////////////////////////////////////////////////// + // Interface specification (public) + public: + /// + virtual bool IsUpdatable() const + { return true; } + + /// calculate gradient + virtual void Gradient() = 0; + /// accumulate gradient from other components + virtual void AccuGradient(const UpdatableComponent& src, int thr, int thrN) = 0; + /// update weights, reset the accumulator + virtual void Update(int thr, int thrN) = 0; + + /// Sets the learning rate of gradient descent + void LearnRate(BaseFloat rate); + /// Gets the learning rate of gradient descent + BaseFloat LearnRate() const; + + void Momentum(BaseFloat mmt); + BaseFloat Momentum() const ; + + void Weightcost(BaseFloat cost); + BaseFloat Weightcost() const; + + void Bunchsize(size_t size); + size_t Bunchsize() const; + + protected: + BaseFloat mLearningRate; + BaseFloat mMomentum; + BaseFloat mWeightcost; + size_t mBunchsize; + }; + + + + + ////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // Component:: + inline + Component:: + Component(size_t nInputs, size_t nOutputs, Component *pPred) + : mNInputs(nInputs), mNOutputs(nOutputs), + mpInput(NULL), mpErrorInput(NULL), + mOutput(), mErrorOutput() + { + /* DOUBLE LINK the Components */ + if (pPred != NULL) { + SetInput(pPred->GetOutput()); + pPred->SetErrorInput(GetErrorOutput()); + } + } + + + inline + Component:: + ~Component() + { + ; + } + + inline void + Component:: + Propagate() + { + //initialize output buffer + if(mOutput.Rows() != GetInput().Rows() || mOutput.Cols() != GetNOutputs()) { + mOutput.Init(GetInput().Rows(),GetNOutputs()); + } + //do the dimensionality test + if(GetNInputs() != GetInput().Cols()) { + KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() + << " Data dim: " << GetInput().Cols(); + } + //run transform + PropagateFnc(GetInput(),mOutput); + + } + + + inline void + Component:: + Backpropagate() + { + //re-initialize the output buffer + if(mErrorOutput.Rows() != GetErrorInput().Rows() || mErrorOutput.Cols() != GetNInputs()) { + mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); + } + + //do the dimensionality test + assert(GetErrorInput().Cols() == mNOutputs); + assert(mErrorOutput.Cols() == mNInputs); + assert(mErrorOutput.Rows() == GetErrorInput().Rows()); + + //transform + BackpropagateFnc(GetErrorInput(),mErrorOutput); + + } + + + inline void + Component:: + SetInput(const Matrix<BaseFloat>& rInput) + { + mpInput = &rInput; + } + + + inline void + Component:: + SetErrorInput(const Matrix<BaseFloat>& rErrorInput) + { + mpErrorInput = &rErrorInput; + } + + + inline const Matrix<BaseFloat>& + Component:: + GetInput() const + { + if (NULL == mpInput) Error("mpInput is NULL"); + return *mpInput; + } + + inline const Matrix<BaseFloat>& + Component:: + GetOutput() const + { + return mOutput; + } + + inline const Matrix<BaseFloat>& + Component:: + GetErrorInput() const + { + if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); + return *mpErrorInput; + } + + inline const Matrix<BaseFloat>& + Component:: + GetErrorOutput() const + { + return mErrorOutput; + } + + inline size_t + Component:: + GetNInputs() const + { + return mNInputs; + } + + inline size_t + Component:: + GetNOutputs() const + { + return mNOutputs; + } + + + + ////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // UpdatableComponent:: + + inline + UpdatableComponent:: + UpdatableComponent(size_t nInputs, size_t nOutputs, Component *pPred) + : Component(nInputs, nOutputs, pPred), + mLearningRate(0.0), mMomentum(0.0), mWeightcost(0.0), mBunchsize(0) + { + ; + } + + + inline + UpdatableComponent:: + ~UpdatableComponent() + { + ; + } + + + inline void + UpdatableComponent:: + LearnRate(BaseFloat rate) + { + mLearningRate = rate; + } + + inline BaseFloat + UpdatableComponent:: + LearnRate() const + { + return mLearningRate; + } + + + inline void + UpdatableComponent:: + Momentum(BaseFloat mmt) + { + mMomentum = mmt; + } + + inline BaseFloat + UpdatableComponent:: + Momentum() const + { + return mMomentum; + } + + + inline void + UpdatableComponent:: + Weightcost(BaseFloat cost) + { + mWeightcost = cost; + } + + inline BaseFloat + UpdatableComponent:: + Weightcost() const + { + return mWeightcost; + } + + + inline void + UpdatableComponent:: + Bunchsize(size_t size) + { + mBunchsize = size; + } + + inline size_t + UpdatableComponent:: + Bunchsize() const + { + return mBunchsize; + } + + +} // namespace TNet + + +#endif |