summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-10-08 16:20:55 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-10-08 16:20:55 +0800
commit2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8 (patch)
treedcff302fc209b088522b021288bbbd5151478754 /src
parent72cebebe44749bd11b6f8dd0f6a58a08f83238cf (diff)
downloadtnet-2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8.tar.gz
tnet-2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8.tar.bz2
tnet-2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8.zip
add getin & get out
Diffstat (limited to 'src')
-rw-r--r--src/CuTNetLib/cuMisc.h74
1 files changed, 56 insertions, 18 deletions
diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h
index 8831622..10418e5 100644
--- a/src/CuTNetLib/cuMisc.h
+++ b/src/CuTNetLib/cuMisc.h
@@ -45,7 +45,7 @@ namespace TNet {
if (NULL == mpInput) Error("mpInput is NULL");
mOutput.Init(*mpInput);
}
- void BackPropagate()
+ void Backpropagate()
{
if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
mErrorOutput.Init(*mpErrorInput);
@@ -232,7 +232,6 @@ namespace TNet {
return size;
}
- /// IO Data getters
const CuMatrix<BaseFloat>& GetInput(int pos=0)
{
if (pos>=0 && pos<size)
@@ -317,6 +316,21 @@ namespace TNet {
rOut<<std::endl;
}
+ const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)
+ {
+ if (pos>=0 && pos<size)
+ return *ErrInputVec[pos];
+ return *ErrInputVec[0];
+ }
+
+ void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)
+ {
+ if (pos==0)
+ mpErrorInput=&rErrorInput;
+ if (pos>=0 && pos<size)
+ ErrInputVec[pos]=&rErrorInput;
+ }
+
void Propagate()
{
if (NULL == mpInput) Error("mpInput is NULL");
@@ -327,21 +341,25 @@ namespace TNet {
loc+=SectLen[i];
}
}
-
- protected:
-
- void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
- { Error("__func__ Nonsense"); }
- void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ void Backpropagate()
{
int loc=0;
+ mErrorOutput.Init(GetErrorInput.Rows(),GetNInput());
for (int i=0;i<size;++i)
{
- Y.CopyCols(SectLen[i], 0, X, loc);
- loc+=SectLen[i];
+ mErrorOutput.CopyCols(SectLen[i], 0, GetErrorInput(i), loc);
+ loc += SectLen[i];
}
}
+
+ protected:
+
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ { Error("__func__ Nonsense"); }
+
+ void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ { Error("__func__ Nonsense"); }
int size;
MatrixPtrVec OutputVec;
@@ -400,28 +418,48 @@ namespace TNet {
rOut<<std::endl;
}
- void Backpropagate()
+ const CuMatrix<BaseFloat>& GetInput(int pos=0)
+ {
+ if (pos>=0 && pos<size)
+ return *InputVec[pos];
+ return *InputVec[0];
+ }
+
+ /// Set input vector (bind with the preceding NetworkComponent)
+ void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)
+ {
+ if (pos==0)
+ mpInput=&rInput;
+ if (pos>=0 && pos<size)
+ InputVec[pos]=&rInput;
+ }
+
+ void Propagate()
{
- if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
int loc=0;
+ mOutput.Init(GetInput.Rows(),GetNOutput());
for (int i=0;i<size;++i)
{
- ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]);
- loc+=SectLen[i];
+ mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc);
+ loc += SectLen[i];
}
}
-
- protected:
- void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ void Backpropagate()
{
+ if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
int loc=0;
for (int i=0;i<size;++i)
{
- Y.CopyCols(SectLen[i], 0, X, loc);
+ ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]);
loc+=SectLen[i];
}
}
+
+ protected:
+
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ { Error("__func__ Nonsense"); }
void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
{ Error("__func__ Nonsense"); }