diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/CuTNetLib/cuComponent.h | 147 |
1 files changed, 45 insertions, 102 deletions
diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h index 364c828..a43680e 100644 --- a/src/CuTNetLib/cuComponent.h +++ b/src/CuTNetLib/cuComponent.h @@ -119,27 +119,62 @@ namespace TNet { void SetNext(CuComponent* pNxt); /// Return the number of different inputs for complex component - virtual int GetInSect(); + virtual int GetInSect(){return 1;} /// Return the number of different outputs for complex component - virtual int GetOutSect(); + virtual int GetOutSect(){return 1;} /// IO Data getters - virtual const CuMatrix<BaseFloat>& GetInput(int pos=0); - virtual const CuMatrix<BaseFloat>& GetOutput(int pos=0); - virtual const CuMatrix<BaseFloat>& GetErrorInput(int pos=0); - virtual const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0); + virtual const CuMatrix<BaseFloat>& GetInput(int pos=0) { + if (NULL == mpInput) Error("mpInput is NULL"); + return *mpInput; + } + virtual const CuMatrix<BaseFloat>& GetOutput(int pos=0) { + return mOutput; + } + virtual const CuMatrix<BaseFloat>& GetErrorInput(int pos=0) { + if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); + return *mpErrorInput; + } + virtual const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0) { + return mErrorOutput; + } /// Set input vector (bind with the preceding NetworkComponent) - virtual void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0); + virtual void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0) { + mpInput = &rInput; + } /// Set error input vector (bind with the following NetworkComponent) - virtual void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0); + virtual void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0) { + mpErrorInput = &rErrorInput; + } /// Perform forward pass propagateion Input->Output, /// wrapper for the PropagateFnc method - virtual void Propagate(); + virtual void Propagate() { + //initialize output buffer + mOutput.Init(GetInput().Rows(),GetNOutputs()); + //do the dimensionality test + if(GetNInputs() != GetInput().Cols()) { + KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() + << " Data dim: " << GetInput().Cols(); + } + //run transform + PropagateF(GetInput(),mOutput); + } /// Perform backward pass propagateion ErrorInput->ErrorOutput, /// wrapper for the BackpropagateFnc method - virtual void Backpropagate(); + virtual void Backpropagate() { + //re-initialize the output buffer + mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); + + //do the dimensionality test + assert(GetErrorInput().Cols() == mNOutputs); + assert(mErrorOutput.Cols() == mNInputs); + assert(mErrorOutput.Rows() == GetErrorInput().Rows()); + + //transform + BackpropagateF(GetErrorInput(),mErrorOutput); + } /// Reads the component parameters from stream virtual void ReadFromStream(std::istream& rIn) { } @@ -281,84 +316,6 @@ namespace TNet { { ; } - - inline void - CuComponent:: - Propagate() - { - //initialize output buffer - mOutput.Init(GetInput().Rows(),GetNOutputs()); - //do the dimensionality test - if(GetNInputs() != GetInput().Cols()) { - KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() - << " Data dim: " << GetInput().Cols(); - } - //run transform - PropagateF(GetInput(),mOutput); - } - - - inline void - CuComponent:: - Backpropagate() - { - //re-initialize the output buffer - mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs()); - - //do the dimensionality test - assert(GetErrorInput().Cols() == mNOutputs); - assert(mErrorOutput.Cols() == mNInputs); - assert(mErrorOutput.Rows() == GetErrorInput().Rows()); - - //transform - BackpropagateF(GetErrorInput(),mErrorOutput); - } - - - inline void - CuComponent:: - SetInput(const CuMatrix<BaseFloat>& rInput,int pos) - { - mpInput = &rInput; - } - - - inline void - CuComponent:: - SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos) - { - mpErrorInput = &rErrorInput; - } - - const CuMatrix<BaseFloat>& - CuComponent:: - GetInput(int pos) - { - if (NULL == mpInput) Error("mpInput is NULL"); - return *mpInput; - } - - const CuMatrix<BaseFloat>& - CuComponent:: - GetOutput(int pos) - { - return mOutput; - } - - const CuMatrix<BaseFloat>& - CuComponent:: - GetErrorInput(int pos) - { - if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); - return *mpErrorInput; - } - - const CuMatrix<BaseFloat>& - CuComponent:: - GetErrorOutput(int pos) - { - return mErrorOutput; - } inline size_t CuComponent:: @@ -373,20 +330,6 @@ namespace TNet { { return mNOutputs; } - - inline int - CuComponent:: - GetInSect() - { - return 1; - } - - inline int - CuComponent:: - GetOutSect() - { - return 1; - } inline size_t CuComponent:: |