diff options
Diffstat (limited to 'src/TNetLib/.svn/text-base/Component.h.svn-base')
| -rw-r--r-- | src/TNetLib/.svn/text-base/Component.h.svn-base | 387 | 
1 files changed, 387 insertions, 0 deletions
| diff --git a/src/TNetLib/.svn/text-base/Component.h.svn-base b/src/TNetLib/.svn/text-base/Component.h.svn-base new file mode 100644 index 0000000..762451e --- /dev/null +++ b/src/TNetLib/.svn/text-base/Component.h.svn-base @@ -0,0 +1,387 @@ +#ifndef _NETWORK_COMPONENT_I_H +#define _NETWORK_COMPONENT_I_H + + +#include "Vector.h" +#include "Matrix.h" + +#include <iostream> +#include <stdexcept> + + +namespace TNet { + +     +  /** +   * Basic element of the network, +   * it is a box with defined inputs and outputs,  +   * and functions to refresh outputs +   * +   * it is able to compute tranformation function (forward pass)  +   * and jacobian function (backward pass),  +   * which is to be implemented in descendents +   */  +  class Component  +  { +    public: +    /// Types of the net components +    typedef enum {  +      UPDATABLE_COMPONENT = 0x0100,  +      BIASED_LINEARITY, +      SHARED_LINEARITY, + +      ACT_FUN = 0x0200,  +      SOFTMAX, +      SIGMOID, +      BLOCK_SOFTMAX,  + +      OTHER = 0x0400, +      EXPAND, +      COPY, +      TRANSPOSE, +      BLOCK_LINEARITY, +      WINDOW, +      BIAS, +      LOG, +       +      BLOCK_ARRAY, +    } ComponentType; + + +    ////////////////////////////////////////////////////////////// +    // Constructor & Destructor +    public:  +      Component(size_t nInputs, size_t nOutputs, Component *pPred);  +      virtual ~Component();   +        +    ////////////////////////////////////////////////////////////// +    // Interface specification (public) +    public: +      /// Get Type Identification of the component +      virtual ComponentType GetType() const = 0;   +      /// Get Type Label of the component +      virtual const char* GetName() const = 0; +      ///  +      virtual bool IsUpdatable() const  +      { return false; } +      /// Clone the component +      virtual Component* Clone() const = 0;  + +      /// Get size of input vectors +      size_t GetNInputs() const;   +      /// Get size of output vectors  +      size_t GetNOutputs() const;  +      +      /// IO Data getters +      const Matrix<BaseFloat>& GetInput() const;  +      const Matrix<BaseFloat>& GetOutput() const; +      const Matrix<BaseFloat>& GetErrorInput() const; +      const Matrix<BaseFloat>& GetErrorOutput() const; +       +      /// Set input vector (bind with the preceding NetworkComponent) +      void SetInput(const Matrix<BaseFloat>& rInput);            +      /// Set error input vector (bind with the following NetworkComponent)  +      void SetErrorInput(const Matrix<BaseFloat>& rErrorInput);   +        +      /// Perform forward pass propagateion Input->Output +      void Propagate();  +      /// Perform backward pass propagateion ErrorInput->ErrorOutput +      void Backpropagate();  +  +      /// Reads the component parameters from stream +      virtual void ReadFromStream(std::istream& rIn)  { } +      /// Writes the components parameters to stream +      virtual void WriteToStream(std::ostream& rOut)  { }  + + +    /////////////////////////////////////////////////////////////// +    // Nonpublic member functions used to update data outputs  +    protected: +      /// Forward pass transformation (to be implemented by descendents...) +      virtual void PropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y) = 0; +      /// Backward pass transformation (to be implemented by descendents...) +      virtual void BackpropagateFnc(const Matrix<BaseFloat>& X, Matrix<BaseFloat>& Y) = 0; + +    +    /////////////////////////////////////////////////////////////// +    // data members +    protected: + +      size_t mNInputs;  ///< Size of input vectors +      size_t mNOutputs; ///< Size of output vectors  +       +      const Matrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component +      const Matrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component + +      Matrix<BaseFloat> mOutput; ///< outputs are OWNED by component +      Matrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component + +  }; + + +  /** +   * Class UpdatableComponent is a box which has some  +   * parameters adjustable by learning +   * +   * you can set the learning rate, lock the params, +   * and learn from each data observation +   */ +  class UpdatableComponent : public Component +  { +    ////////////////////////////////////////////////////////////// +    // Constructor & Destructor +    public:  +      UpdatableComponent(size_t nInputs, size_t nOutputs, Component *pPred);  +      virtual ~UpdatableComponent(); + + +    ////////////////////////////////////////////////////////////// +    // Interface specification (public) +    public: +      /// +      virtual bool IsUpdatable() const  +      { return true; } + +      /// calculate gradient +      virtual void Gradient() = 0; +      /// accumulate gradient from other components +      virtual void AccuGradient(const UpdatableComponent& src, int thr, int thrN) = 0;   +      /// update weights, reset the accumulator +      virtual void Update(int thr, int thrN) = 0; + +      /// Sets the learning rate of gradient descent +      void LearnRate(BaseFloat rate); +      /// Gets the learning rate of gradient descent +      BaseFloat LearnRate() const; + +      void Momentum(BaseFloat mmt); +      BaseFloat Momentum() const ; + +      void Weightcost(BaseFloat cost); +      BaseFloat Weightcost() const; + +      void Bunchsize(size_t size); +      size_t Bunchsize() const; + +    protected: +      BaseFloat mLearningRate; +      BaseFloat mMomentum; +      BaseFloat mWeightcost; +      size_t mBunchsize; +  }; + + + + +  ////////////////////////////////////////////////////////////////////////// +  // INLINE FUNCTIONS  +  // Component:: +  inline +  Component:: +  Component(size_t nInputs, size_t nOutputs, Component *pPred)  +    : mNInputs(nInputs), mNOutputs(nOutputs),  +      mpInput(NULL), mpErrorInput(NULL),  +      mOutput(), mErrorOutput() +  {  +    /* DOUBLE LINK the Components */ +    if (pPred != NULL) { +      SetInput(pPred->GetOutput()); +      pPred->SetErrorInput(GetErrorOutput()); +    } +  }  + + +  inline +  Component:: +  ~Component() +  { +    ; +  } + +  inline void +  Component:: +  Propagate() +  { +    //initialize output buffer +    if(mOutput.Rows() != GetInput().Rows() || mOutput.Cols() != GetNOutputs()) { +      mOutput.Init(GetInput().Rows(),GetNOutputs()); +    } +    //do the dimensionality test +    if(GetNInputs() != GetInput().Cols()) { +      KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs()  +                << " Data dim: " << GetInput().Cols(); +    } +    //run transform +    PropagateFnc(GetInput(),mOutput); +   +  } + + +  inline void +  Component:: +  Backpropagate() +  { +    //re-initialize the output buffer +    if(mErrorOutput.Rows() != GetErrorInput().Rows() || mErrorOutput.Cols() != GetNInputs()) { +      mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); +    } + +    //do the dimensionality test +    assert(GetErrorInput().Cols() == mNOutputs); +    assert(mErrorOutput.Cols() == mNInputs); +    assert(mErrorOutput.Rows() == GetErrorInput().Rows()); + +    //transform +    BackpropagateFnc(GetErrorInput(),mErrorOutput); + + } + + +  inline void +  Component:: +  SetInput(const Matrix<BaseFloat>& rInput) +  { +    mpInput = &rInput; +  } + + +  inline void +  Component:: +  SetErrorInput(const Matrix<BaseFloat>& rErrorInput) +  { +    mpErrorInput = &rErrorInput; +  } + + +  inline const Matrix<BaseFloat>& +  Component:: +  GetInput() const +  { +    if (NULL == mpInput) Error("mpInput is NULL"); +    return *mpInput; +  } + +  inline const Matrix<BaseFloat>& +  Component:: +  GetOutput() const +  { +    return mOutput; +  } + +  inline const Matrix<BaseFloat>& +  Component:: +  GetErrorInput() const +  { +    if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); +    return *mpErrorInput; +  } + +  inline const Matrix<BaseFloat>& +  Component:: +  GetErrorOutput() const +  { +    return mErrorOutput; +  } + +  inline size_t +  Component:: +  GetNInputs() const +  { +    return mNInputs; +  } + +  inline size_t +  Component:: +  GetNOutputs() const +  { +    return mNOutputs; +  } + + + +  ////////////////////////////////////////////////////////////////////////// +  // INLINE FUNCTIONS  +  // UpdatableComponent:: +   +  inline  +  UpdatableComponent:: +  UpdatableComponent(size_t nInputs, size_t nOutputs, Component *pPred)  +    : Component(nInputs, nOutputs, pPred),  +      mLearningRate(0.0), mMomentum(0.0), mWeightcost(0.0), mBunchsize(0) +  { +    ;  +  }  + + +  inline +  UpdatableComponent:: +  ~UpdatableComponent() +  { +    ; +  } + + +  inline void +  UpdatableComponent:: +  LearnRate(BaseFloat rate) +  { +    mLearningRate = rate; +  } + +  inline BaseFloat +  UpdatableComponent:: +  LearnRate() const +  { +    return mLearningRate; +  } + + +  inline void +  UpdatableComponent:: +  Momentum(BaseFloat mmt) +  { +    mMomentum = mmt; +  } + +  inline BaseFloat +  UpdatableComponent:: +  Momentum() const +  { +    return mMomentum; +  } +   +   +  inline void +  UpdatableComponent:: +  Weightcost(BaseFloat cost) +  { +    mWeightcost = cost; +  } + +  inline BaseFloat +  UpdatableComponent:: +  Weightcost() const +  { +    return mWeightcost; +  } + +   +  inline void +  UpdatableComponent:: +  Bunchsize(size_t size) +  { +    mBunchsize = size; +  } + +  inline size_t +  UpdatableComponent:: +  Bunchsize() const +  { +    return mBunchsize; +  } + + +} // namespace TNet + + +#endif | 
