From 2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8 Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Wed, 8 Oct 2014 16:20:55 +0800 Subject: add getin & get out --- src/CuTNetLib/cuMisc.h | 74 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 18 deletions(-) (limited to 'src/CuTNetLib') diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h index 8831622..10418e5 100644 --- a/src/CuTNetLib/cuMisc.h +++ b/src/CuTNetLib/cuMisc.h @@ -45,7 +45,7 @@ namespace TNet { if (NULL == mpInput) Error("mpInput is NULL"); mOutput.Init(*mpInput); } - void BackPropagate() + void Backpropagate() { if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); mErrorOutput.Init(*mpErrorInput); @@ -232,7 +232,6 @@ namespace TNet { return size; } - /// IO Data getters const CuMatrix& GetInput(int pos=0) { if (pos>=0 && pos& GetErrorInput(int pos=0) + { + if (pos>=0 && pos& rErrorInput,int pos=0) + { + if (pos==0) + mpErrorInput=&rErrorInput; + if (pos>=0 && pos& X, CuMatrix& Y) - { Error("__func__ Nonsense"); } - void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) + void Backpropagate() { int loc=0; + mErrorOutput.Init(GetErrorInput.Rows(),GetNInput()); for (int i=0;i& X, CuMatrix& Y) + { Error("__func__ Nonsense"); } + + void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) + { Error("__func__ Nonsense"); } int size; MatrixPtrVec OutputVec; @@ -400,28 +418,48 @@ namespace TNet { rOut<& GetInput(int pos=0) + { + if (pos>=0 && pos& rInput,int pos=0) + { + if (pos==0) + mpInput=&rInput; + if (pos>=0 && posInit(*mpErrorInput,loc,SectLen[i]); - loc+=SectLen[i]; + mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc); + loc += SectLen[i]; } } - - protected: - void PropagateFnc(const CuMatrix& X, CuMatrix& Y) + void Backpropagate() { + if (NULL == mpErrorInput) Error("mpErrorInput is NULL"); int loc=0; for (int i=0;iInit(*mpErrorInput,loc,SectLen[i]); loc+=SectLen[i]; } } + + protected: + + void PropagateFnc(const CuMatrix& X, CuMatrix& Y) + { Error("__func__ Nonsense"); } void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) { Error("__func__ Nonsense"); } -- cgit v1.2.3-70-g09d2