summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuComponent.h
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuComponent.h
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/CuTNetLib/cuComponent.h')
-rw-r--r--src/CuTNetLib/cuComponent.h505
1 files changed, 505 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h
new file mode 100644
index 0000000..6cc8462
--- /dev/null
+++ b/src/CuTNetLib/cuComponent.h
@@ -0,0 +1,505 @@
+#ifndef _CUNETWORK_COMPONENT_I_H
+#define _CUNETWORK_COMPONENT_I_H
+
+
+#include "Vector.h"
+#include "Matrix.h"
+#include "Error.h"
+
+#include "cumatrix.h"
+
+#include <iostream>
+#include <stdexcept>
+#include <vector>
+
+/// \defgroup CuNNLayer CuNN Layer types
+/// \ingroup CuNNComp
+
+/// \defgroup CuNNUpdatable CuNN Updatable Layer
+/// \ingroup CuNNLayer
+
+/// \defgroup CuNNActivation CuNN Activation Func Layer
+/// \ingroup CuNNLayer
+
+/// \defgroup CuNNMisc CuNN Misc Layer
+/// \ingroup CuNNLayer
+
+namespace TNet {
+
+
+ /**
+  * \brief Neural network building blocks
+  *
+  * Basic element of the network,
+  * it is a box with defined inputs and outputs,
+  * and functions to refresh outputs
+  *
+  * it is able to compute the transformation function (forward pass)
+  * and the jacobian function (backward pass),
+  * which are to be implemented in descendents
+  *
+  * Components together form a doubly-linked list.
+  *
+  * Each input row is a frame of data,
+  * and the input size is the number of columns
+  */
+ class CuComponent
+ {
+  public:
+   /// Types of the net components.
+   /// The 0x0100/0x0200/0x0400 bases partition the IDs into
+   /// updatable / activation-function / other categories.
+   typedef enum {
+    UPDATABLE_COMPONENT = 0x0100,
+    BIASED_LINEARITY,
+    DISCRETE_LINEARITY,
+    SHARED_LINEARITY,
+    SPARSE_LINEARITY,
+    RBM,
+    RBM_SPARSE,
+    RECURRENT,
+    LINEARITY,
+    UPDATABLEBIAS,
+    DISCRETE,
+    COMPOUND,
+
+    ACT_FUN = 0x0200,
+    SOFTMAX,
+    SIGMOID,
+
+    OTHER = 0x0400,
+    EXPAND,
+    COPY,
+    TRANSPOSE,
+    BLOCK_LINEARITY,
+    WINDOW,
+    BIAS,
+    LOG,
+    PIPE,
+    LEARNSTOP,
+    DISTRIB,
+    COMBINE,
+    DIVIDE,
+    MERGE,
+    REORDER,
+
+    BLOCK_ARRAY,
+   } ComponentType;
+
+   /// Convenience alias for a vector of non-owning matrix pointers.
+   typedef std::vector< CuMatrix<BaseFloat>* > MatrixPtrVec;
+
+   //////////////////////////////////////////////////////////////
+   // Constructor & Destructor
+  public:
+   /// Construct with given input/output sizes; if pPred is non-NULL,
+   /// links this component after pPred in the doubly-linked list.
+   CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+   virtual ~CuComponent();
+
+   //////////////////////////////////////////////////////////////
+   // Interface specification (public)
+  public:
+   /// Get Type Identification of the component
+   virtual ComponentType GetType() const = 0;
+   /// Get Type Label of the component
+   virtual const char* GetName() const = 0;
+   /// Return if the component is UpdatableComponent
+   virtual bool IsUpdatable() const
+   { return false; }
+
+   /// Get size of input vectors
+   size_t GetNInputs() const;
+   /// Get size of output vectors
+   size_t GetNOutputs() const;
+
+   /// Set size of input vectors
+   size_t SetNInputs(size_t nInputs);
+   /// Set size of output vectors
+   size_t SetNOutputs(size_t nOutputs);
+   /// Set the previous component (also binds its output as our input)
+   void SetPrevious(CuComponent* pPred);
+   /// Set the next component (also binds its error output as our error input)
+   void SetNext(CuComponent* pNxt);
+
+   /// Return the number of different inputs for complex component
+   int GetInSect();
+   /// Return the number of different outputs for complex component
+   int GetOutSect();
+
+   /// IO Data getters
+   /// (pos selects the section for complex components; the base
+   /// class ignores it and exposes a single section)
+   CuMatrix<BaseFloat>& GetInput(int pos=0);
+   CuMatrix<BaseFloat>& GetOutput(int pos=0);
+   CuMatrix<BaseFloat>& GetErrorInput(int pos=0);
+   CuMatrix<BaseFloat>& GetErrorOutput(int pos=0);
+
+   /// Set input vector (bind with the preceding NetworkComponent)
+   void SetInput(CuMatrix<BaseFloat>& rInput,int pos=0);
+   /// Set error input vector (bind with the following NetworkComponent)
+   void SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos=0);
+
+   /// Perform forward pass propagation Input->Output,
+   /// wrapper for the PropagateFnc method
+   void Propagate();
+   /// Perform backward pass propagation ErrorInput->ErrorOutput,
+   /// wrapper for the BackpropagateFnc method
+   void Backpropagate();
+
+   /// Reads the component parameters from stream (no-op by default)
+   virtual void ReadFromStream(std::istream& rIn) { }
+   /// Writes the components parameters to stream (no-op by default)
+   virtual void WriteToStream(std::ostream& rOut) { }
+
+   /// Public wrapper for PropagateFnc
+   void PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+   /// Public wrapper for BackpropagateFnc
+   void BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+
+
+   ///////////////////////////////////////////////////////////////
+   // Nonpublic member functions used to update data outputs
+  protected:
+   /// Forward pass transformation (to be implemented by descendents...)
+   /// \param[in] X InputMatrix (Network input or Output from last layer)
+   /// \param[out] Y OutputMatrix (Network output or input of the next layer)
+   virtual void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) = 0;
+   /// Backward pass transformation (to be implemented by descendents...)
+   /// \param[in] X InputMatrix (Network Error, objective func output, or Error output from the next layer)
+   /// \param[out] Y OutputMatrix (Error input of the last layer)
+   virtual void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) = 0;
+
+
+   ///////////////////////////////////////////////////////////////
+   // data members
+  protected:
+
+   size_t mNInputs;  ///< Size of input vectors
+   size_t mNOutputs; ///< Size of output vectors
+
+   CuMatrix<BaseFloat>* mpInput;     ///< inputs are NOT OWNED by component
+   CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component
+
+   CuMatrix<BaseFloat> mOutput;      ///< outputs are OWNED by component
+   CuMatrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component
+
+   CuComponent* preComp;///< The preceding component in the Network
+   CuComponent* nxtComp;///< The following component in the Network
+ };
+
+
+ /**
+  * \brief Class UpdatableComponent is a box which has some
+  * parameters adjustable by learning
+  *
+  * you can set the learning rate, lock the params,
+  * and learn from each data observation
+  */
+ class CuUpdatableComponent : public CuComponent
+ {
+  //////////////////////////////////////////////////////////////
+  // Constructor & Destructor
+  public:
+   /// Construct; training hyper-parameters start zeroed, with
+   /// per-frame gradient division enabled (see inline definition).
+   CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+   virtual ~CuUpdatableComponent();
+
+
+  //////////////////////////////////////////////////////////////
+  // Interface specification (public)
+  public:
+   /// Return if CuUpdatableComponent is updatable?
+   virtual bool IsUpdatable() const
+   { return true; }
+
+   /// get gradient and update the parameters in one step
+   virtual void Update() = 0;
+
+   /// Sets the learning rate of gradient descent
+   void LearnRate(BaseFloat rate);
+   /// Gets the learning rate of gradient descent
+   BaseFloat LearnRate();
+
+   /// Sets the momentum
+   void Momentum(BaseFloat mmt);
+   /// Gets the momentum
+   BaseFloat Momentum();
+
+   /// Set the weight decay rate to penalize large weights
+   void Weightcost(BaseFloat cost);
+   /// Get the weight decay rate
+   BaseFloat Weightcost();
+
+   /// Set whether gradient is divided by frames
+   void GradDivFrm(bool div);
+   /// Get whether gradient is divided by frames
+   bool GradDivFrm();
+
+  protected:
+   BaseFloat mLearningRate; ///< learning rate for gradient descent
+   BaseFloat mMomentum;     ///< momentum coefficient
+   BaseFloat mWeightcost;   ///< weight decay (penalty on large weights)
+   bool mGradDivFrm;        ///< divide gradient by number of frames?
+
+ };
+
+
+
+
+ //////////////////////////////////////////////////////////////////////////
+ // INLINE FUNCTIONS
+ // CuComponent::
+ /// Construct a component with the given dimensionalities and, when
+ /// pPred is non-NULL, splice it into the doubly-linked list right
+ /// after pPred (binding pPred's output buffer as our input).
+ /// FIX: nxtComp was previously left uninitialized — SetNext() is only
+ /// invoked on the *predecessor* here, so a tail component carried a
+ /// dangling next-pointer. It is now NULL-initialized.
+ inline
+ CuComponent::
+ CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+  : mNInputs(nInputs), mNOutputs(nOutputs),
+    mpInput(NULL), mpErrorInput(NULL),
+    mOutput(), mErrorOutput(),
+    preComp(pPred), nxtComp(NULL)
+ {
+  /* DOUBLE LINK the Components */
+  if (pPred != NULL) {
+   SetPrevious(pPred);
+   pPred->SetNext(this);
+  }
+ }
+
+ /// Record pPred as the preceding component and, when present, bind
+ /// its output buffer as this component's input (no ownership taken).
+ inline void CuComponent::SetPrevious(CuComponent* pPred)
+ {
+  preComp = pPred;
+  if (pPred == NULL) return;  // head of the list: nothing to bind
+  /* DOUBLE LINK the Components */
+  SetInput(pPred->GetOutput());
+ }
+
+ /// Record pNxt as the following component and, when present, bind
+ /// its error-output buffer as this component's error input.
+ inline void CuComponent::SetNext(CuComponent* pNxt)
+ {
+  nxtComp = pNxt;
+  if (pNxt == NULL) return;  // tail of the list: nothing to bind
+  SetErrorInput(pNxt->GetErrorOutput());
+ }
+
+ /// Destructor (declared virtual in the class body). Nothing to free:
+ /// the only owned members (mOutput, mErrorOutput) clean themselves up;
+ /// mpInput/mpErrorInput are non-owning bindings.
+ inline CuComponent::~CuComponent()
+ {
+ }
+
+ // Forward-pass wrapper: sizes the owned output buffer (one row per
+ // input frame), checks that the bound input width matches the layer's
+ // declared input dimensionality, then delegates to PropagateF, which
+ // calls the virtual PropagateFnc.
+ inline void
+ CuComponent::
+ Propagate()
+ {
+  //initialize output buffer
+  mOutput.Init(GetInput().Rows(),GetNOutputs());
+  //do the dimensionality test
+  if(GetNInputs() != GetInput().Cols()) {
+   KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs()
+             << " Data dim: " << GetInput().Cols();
+  }
+  //run transform
+  PropagateF(GetInput(),mOutput);
+ }
+
+
+ // Backward-pass wrapper: sizes the owned error-output buffer, verifies
+ // the error dimensions against the layer's declared sizes (debug-only
+ // asserts), then delegates to BackpropagateF, which calls the virtual
+ // BackpropagateFnc.
+ inline void
+ CuComponent::
+ Backpropagate()
+ {
+  //re-initialize the output buffer
+  mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs());
+
+  //do the dimensionality test
+  assert(GetErrorInput().Cols() == mNOutputs);
+  assert(mErrorOutput.Cols() == mNInputs);
+  assert(mErrorOutput.Rows() == GetErrorInput().Rows());
+
+  //transform
+  BackpropagateF(GetErrorInput(),mErrorOutput);
+ }
+
+
+ /// Bind the forward-pass input buffer; the matrix is NOT owned
+ /// by this component (pos is ignored by the base class).
+ inline void CuComponent::SetInput(CuMatrix<BaseFloat>& rInput,int pos)
+ {
+  mpInput = &rInput;
+ }
+
+ /// Bind the backward-pass error-input buffer; the matrix is NOT
+ /// owned by this component (pos is ignored by the base class).
+ inline void CuComponent::SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos)
+ {
+  mpErrorInput = &rErrorInput;
+ }
+
+ inline CuMatrix<BaseFloat>&
+ CuComponent::
+ GetInput(int pos)
+ {
+ if (NULL == mpInput) Error("mpInput is NULL");
+ return *mpInput;
+ }
+
+ inline CuMatrix<BaseFloat>&
+ CuComponent::
+ GetOutput(int pos)
+ {
+ return mOutput;
+ }
+
+ inline CuMatrix<BaseFloat>&
+ CuComponent::
+ GetErrorInput(int pos)
+ {
+ if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
+ return *mpErrorInput;
+ }
+
+ inline CuMatrix<BaseFloat>&
+ CuComponent::
+ GetErrorOutput(int pos)
+ {
+ return mErrorOutput;
+ }
+
+ /// Input dimensionality (number of columns expected on input).
+ inline size_t CuComponent::GetNInputs() const
+ {
+  return mNInputs;
+ }
+
+ /// Output dimensionality (number of columns produced on output).
+ inline size_t CuComponent::GetNOutputs() const
+ {
+  return mNOutputs;
+ }
+
+ /// Number of input sections — always 1 for a basic component
+ /// (the multi-section case is for complex components; see the
+ /// declaration comment).
+ inline int CuComponent::GetInSect()
+ {
+  return 1;
+ }
+
+ /// Number of output sections — always 1 for a basic component.
+ inline int CuComponent::GetOutSect()
+ {
+  return 1;
+ }
+
+ /// Set the input dimensionality and return the value now in effect.
+ /// FIX: declared to return size_t but the original body had no return
+ /// statement — flowing off the end of a value-returning function is
+ /// undefined behavior in C++. Returning the stored value keeps any
+ /// caller that ignores the result unaffected.
+ inline size_t
+ CuComponent::
+ SetNInputs(size_t nInputs)
+ {
+  mNInputs=nInputs;
+  return mNInputs;
+ }
+
+ /// Set the output dimensionality and return the value now in effect.
+ /// FIX: same missing-return defect as SetNInputs.
+ inline size_t
+ CuComponent::
+ SetNOutputs(size_t nOutputs)
+ {
+  mNOutputs=nOutputs;
+  return mNOutputs;
+ }
+
+ /// Public non-virtual wrapper over the protected PropagateFnc hook.
+ inline void CuComponent::PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+  PropagateFnc(X, Y);
+ }
+
+ /// Public non-virtual wrapper over the protected BackpropagateFnc hook.
+ inline void CuComponent::BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+  BackpropagateFnc(X, Y);
+ }
+
+
+ //////////////////////////////////////////////////////////////////////////
+ // INLINE FUNCTIONS
+ // UpdatableComponent::
+
+ /// Construct an updatable component: all training hyper-parameters
+ /// start at zero, with per-frame gradient division enabled.
+ inline CuUpdatableComponent::CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+  : CuComponent(nInputs, nOutputs, pPred),
+    mLearningRate(0.0), mMomentum(0), mWeightcost(0), mGradDivFrm(true)
+ {
+ }
+
+
+ /// Destructor — nothing extra to release beyond the base class.
+ inline CuUpdatableComponent::~CuUpdatableComponent()
+ {
+ }
+
+
+ /// Set the gradient-descent learning rate.
+ inline void CuUpdatableComponent::LearnRate(BaseFloat rate)
+ {
+  mLearningRate = rate;
+ }
+
+ /// Get the gradient-descent learning rate.
+ inline BaseFloat CuUpdatableComponent::LearnRate()
+ {
+  return mLearningRate;
+ }
+
+ /// Set the momentum coefficient.
+ inline void CuUpdatableComponent::Momentum(BaseFloat mmt)
+ {
+  mMomentum = mmt;
+ }
+
+ /// Get the momentum coefficient.
+ inline BaseFloat CuUpdatableComponent::Momentum()
+ {
+  return mMomentum;
+ }
+
+ /// Set the weight-decay rate penalizing large weights.
+ inline void CuUpdatableComponent::Weightcost(BaseFloat cost)
+ {
+  mWeightcost = cost;
+ }
+
+ /// Get the weight-decay rate.
+ inline BaseFloat CuUpdatableComponent::Weightcost()
+ {
+  return mWeightcost;
+ }
+
+ /// Set whether the gradient is divided by the number of frames.
+ inline void CuUpdatableComponent::GradDivFrm(bool div)
+ {
+  mGradDivFrm = div;
+ }
+
+ /// Get whether the gradient is divided by the number of frames.
+ inline bool CuUpdatableComponent::GradDivFrm()
+ {
+  return mGradDivFrm;
+ }
+
+} // namespace TNet
+
+
+#endif