diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-10-08 16:20:55 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-10-08 16:20:55 +0800 |
commit | 2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8 (patch) | |
tree | dcff302fc209b088522b021288bbbd5151478754 /src | |
parent | 72cebebe44749bd11b6f8dd0f6a58a08f83238cf (diff) | |
download | tnet-2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8.tar.gz tnet-2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8.tar.bz2 tnet-2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8.zip |
add getin & get out
Diffstat (limited to 'src')
-rw-r--r-- | src/CuTNetLib/cuMisc.h | 74 |
1 files changed, 56 insertions, 18 deletions
diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h index 8831622..10418e5 100644 --- a/src/CuTNetLib/cuMisc.h +++ b/src/CuTNetLib/cuMisc.h @@ -45,7 +45,7 @@ namespace TNet { if (NULL == mpInput) Error("mpInput is NULL"); mOutput.Init(*mpInput); } - void BackPropagate() + void Backpropagate() { if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); mErrorOutput.Init(*mpErrorInput); @@ -232,7 +232,6 @@ namespace TNet { return size; } - /// IO Data getters const CuMatrix<BaseFloat>& GetInput(int pos=0) { if (pos>=0 && pos<size) @@ -317,6 +316,21 @@ namespace TNet { rOut<<std::endl; } + const CuMatrix<BaseFloat>& GetErrorInput(int pos=0) + { + if (pos>=0 && pos<size) + return *ErrInputVec[pos]; + return *ErrInputVec[0]; + } + + void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0) + { + if (pos==0) + mpErrorInput=&rErrorInput; + if (pos>=0 && pos<size) + ErrInputVec[pos]=&rErrorInput; + } + void Propagate() { if (NULL == mpInput) Error("mpInput is NULL"); @@ -327,21 +341,25 @@ namespace TNet { loc+=SectLen[i]; } } - - protected: - - void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) - { Error("__func__ Nonsense"); } - void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + void Backpropagate() { int loc=0; + mErrorOutput.Init(GetErrorInput.Rows(),GetNInput()); for (int i=0;i<size;++i) { - Y.CopyCols(SectLen[i], 0, X, loc); - loc+=SectLen[i]; + mErrorOutput.CopyCols(SectLen[i], 0, GetErrorInput(i), loc); + loc += SectLen[i]; } } + + protected: + + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { Error("__func__ Nonsense"); } + + void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { Error("__func__ Nonsense"); } int size; MatrixPtrVec OutputVec; @@ -400,28 +418,48 @@ namespace TNet { rOut<<std::endl; } - void Backpropagate() + const CuMatrix<BaseFloat>& GetInput(int pos=0) + { + if (pos>=0 && pos<size) + return *InputVec[pos]; + return *InputVec[0]; + } + + /// Set input vector (bind with the preceding NetworkComponent) + void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0) + { + if (pos==0) + mpInput=&rInput; + if (pos>=0 && pos<size) + InputVec[pos]=&rInput; + } + + void Propagate() { - if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); int loc=0; + mOutput.Init(GetInput.Rows(),GetNOutput()); for (int i=0;i<size;++i) { - ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]); - loc+=SectLen[i]; + mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc); + loc += SectLen[i]; } } - - protected: - void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + void Backpropagate() { + if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); int loc=0; for (int i=0;i<size;++i) { - Y.CopyCols(SectLen[i], 0, X, loc); + ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]); loc+=SectLen[i]; } } + + protected: + + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { Error("__func__ Nonsense"); } void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) { Error("__func__ Nonsense"); } |