diff options
Diffstat (limited to 'src/CuTNetLib/cuMisc.h')
-rw-r--r-- | src/CuTNetLib/cuMisc.h | 74 |
1 files changed, 56 insertions, 18 deletions
diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h index 8831622..10418e5 100644 --- a/src/CuTNetLib/cuMisc.h +++ b/src/CuTNetLib/cuMisc.h @@ -45,7 +45,7 @@ namespace TNet { if (NULL == mpInput) Error("mpInput is NULL"); mOutput.Init(*mpInput); } - void BackPropagate() + void Backpropagate() { if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); mErrorOutput.Init(*mpErrorInput); @@ -232,7 +232,6 @@ namespace TNet { return size; } - /// IO Data getters const CuMatrix<BaseFloat>& GetInput(int pos=0) { if (pos>=0 && pos<size) @@ -317,6 +316,21 @@ namespace TNet { rOut<<std::endl; } + const CuMatrix<BaseFloat>& GetErrorInput(int pos=0) + { + if (pos>=0 && pos<size) + return *ErrInputVec[pos]; + return *ErrInputVec[0]; + } + + void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0) + { + if (pos==0) + mpErrorInput=&rErrorInput; + if (pos>=0 && pos<size) + ErrInputVec[pos]=&rErrorInput; + } + void Propagate() { if (NULL == mpInput) Error("mpInput is NULL"); @@ -327,21 +341,25 @@ namespace TNet { loc+=SectLen[i]; } } - - protected: - - void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) - { Error("__func__ Nonsense"); } - void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + void Backpropagate() { int loc=0; + mErrorOutput.Init(GetErrorInput.Rows(),GetNInput()); for (int i=0;i<size;++i) { - Y.CopyCols(SectLen[i], 0, X, loc); - loc+=SectLen[i]; + mErrorOutput.CopyCols(SectLen[i], 0, GetErrorInput(i), loc); + loc += SectLen[i]; } } + + protected: + + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { Error("__func__ Nonsense"); } + + void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { Error("__func__ Nonsense"); } int size; MatrixPtrVec OutputVec; @@ -400,28 +418,48 @@ namespace TNet { rOut<<std::endl; } - void Backpropagate() + const CuMatrix<BaseFloat>& GetInput(int pos=0) + { + if (pos>=0 && pos<size) + return *InputVec[pos]; + return *InputVec[0]; + } + + /// Set input vector (bind with the preceding NetworkComponent) + void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0) + { + if (pos==0) + mpInput=&rInput; + if (pos>=0 && pos<size) + InputVec[pos]=&rInput; + } + + void Propagate() { - if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); int loc=0; + mOutput.Init(GetInput.Rows(),GetNOutput()); for (int i=0;i<size;++i) { - ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]); - loc+=SectLen[i]; + mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc); + loc += SectLen[i]; } } - - protected: - void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + void Backpropagate() { + if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); int loc=0; for (int i=0;i<size;++i) { - Y.CopyCols(SectLen[i], 0, X, loc); + ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]); loc+=SectLen[i]; } } + + protected: + + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { Error("__func__ Nonsense"); } void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) { Error("__func__ Nonsense"); } |