summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuComponent.h
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-10-06 15:05:44 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-10-06 15:05:44 +0800
commitf910ddf6ed09344e8cea840d81dd63521cd96b45 (patch)
tree3297f8ea81e05b5aa5363bb4009736d8970ff789 /src/CuTNetLib/cuComponent.h
parente246076a03c2496ca8ef814b174b5b741928a1b6 (diff)
downloadtnet-f910ddf6ed09344e8cea840d81dd63521cd96b45.tar.gz
tnet-f910ddf6ed09344e8cea840d81dd63521cd96b45.tar.bz2
tnet-f910ddf6ed09344e8cea840d81dd63521cd96b45.zip
Virtual 5
Diffstat (limited to 'src/CuTNetLib/cuComponent.h')
-rw-r--r--src/CuTNetLib/cuComponent.h147
1 files changed, 45 insertions, 102 deletions
diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h
index 364c828..a43680e 100644
--- a/src/CuTNetLib/cuComponent.h
+++ b/src/CuTNetLib/cuComponent.h
@@ -119,27 +119,62 @@ namespace TNet {
void SetNext(CuComponent* pNxt);
/// Return the number of different inputs for complex component
- virtual int GetInSect();
+ virtual int GetInSect(){return 1;}
/// Return the number of different outputs for complex component
- virtual int GetOutSect();
+ virtual int GetOutSect(){return 1;}
/// IO Data getters
- virtual const CuMatrix<BaseFloat>& GetInput(int pos=0);
- virtual const CuMatrix<BaseFloat>& GetOutput(int pos=0);
- virtual const CuMatrix<BaseFloat>& GetErrorInput(int pos=0);
- virtual const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0);
+ virtual const CuMatrix<BaseFloat>& GetInput(int pos=0) {
+ if (NULL == mpInput) Error("mpInput is NULL");
+ return *mpInput;
+ }
+ virtual const CuMatrix<BaseFloat>& GetOutput(int pos=0) {
+ return mOutput;
+ }
+ virtual const CuMatrix<BaseFloat>& GetErrorInput(int pos=0) {
+ if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
+ return *mpErrorInput;
+ }
+ virtual const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0) {
+ return mErrorOutput;
+ }
/// Set input vector (bind with the preceding NetworkComponent)
- virtual void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0);
+ virtual void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0) {
+ mpInput = &rInput;
+ }
/// Set error input vector (bind with the following NetworkComponent)
- virtual void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0);
+ virtual void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0) {
+ mpErrorInput = &rErrorInput;
+ }
/// Perform forward pass propagateion Input->Output,
/// wrapper for the PropagateFnc method
- virtual void Propagate();
+ virtual void Propagate() {
+ //initialize output buffer
+ mOutput.Init(GetInput().Rows(),GetNOutputs());
+ //do the dimensionality test
+ if(GetNInputs() != GetInput().Cols()) {
+ KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs()
+ << " Data dim: " << GetInput().Cols();
+ }
+ //run transform
+ PropagateF(GetInput(),mOutput);
+ }
/// Perform backward pass propagateion ErrorInput->ErrorOutput,
/// wrapper for the BackpropagateFnc method
- virtual void Backpropagate();
+ virtual void Backpropagate() {
+ //re-initialize the output buffer
+ mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs());
+
+ //do the dimensionality test
+ assert(GetErrorInput().Cols() == mNOutputs);
+ assert(mErrorOutput.Cols() == mNInputs);
+ assert(mErrorOutput.Rows() == GetErrorInput().Rows());
+
+ //transform
+ BackpropagateF(GetErrorInput(),mErrorOutput);
+ }
/// Reads the component parameters from stream
virtual void ReadFromStream(std::istream& rIn) { }
@@ -281,84 +316,6 @@ namespace TNet {
{
;
}
-
- inline void
- CuComponent::
- Propagate()
- {
- //initialize output buffer
- mOutput.Init(GetInput().Rows(),GetNOutputs());
- //do the dimensionality test
- if(GetNInputs() != GetInput().Cols()) {
- KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs()
- << " Data dim: " << GetInput().Cols();
- }
- //run transform
- PropagateF(GetInput(),mOutput);
- }
-
-
- inline void
- CuComponent::
- Backpropagate()
- {
- //re-initialize the output buffer
- mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs());
-
- //do the dimensionality test
- assert(GetErrorInput().Cols() == mNOutputs);
- assert(mErrorOutput.Cols() == mNInputs);
- assert(mErrorOutput.Rows() == GetErrorInput().Rows());
-
- //transform
- BackpropagateF(GetErrorInput(),mErrorOutput);
- }
-
-
- inline void
- CuComponent::
- SetInput(const CuMatrix<BaseFloat>& rInput,int pos)
- {
- mpInput = &rInput;
- }
-
-
- inline void
- CuComponent::
- SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos)
- {
- mpErrorInput = &rErrorInput;
- }
-
- const CuMatrix<BaseFloat>&
- CuComponent::
- GetInput(int pos)
- {
- if (NULL == mpInput) Error("mpInput is NULL");
- return *mpInput;
- }
-
- const CuMatrix<BaseFloat>&
- CuComponent::
- GetOutput(int pos)
- {
- return mOutput;
- }
-
- const CuMatrix<BaseFloat>&
- CuComponent::
- GetErrorInput(int pos)
- {
- if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
- return *mpErrorInput;
- }
-
- const CuMatrix<BaseFloat>&
- CuComponent::
- GetErrorOutput(int pos)
- {
- return mErrorOutput;
- }
inline size_t
CuComponent::
@@ -373,20 +330,6 @@ namespace TNet {
{
return mNOutputs;
}
-
- inline int
- CuComponent::
- GetInSect()
- {
- return 1;
- }
-
- inline int
- CuComponent::
- GetOutSect()
- {
- return 1;
- }
inline size_t
CuComponent::