From 2689091c8749ec8f2d8099a2c43c7a1fbeecbdf8 Mon Sep 17 00:00:00 2001
From: Joe Zhao <ztuowen@gmail.com>
Date: Wed, 8 Oct 2014 16:20:55 +0800
Subject: add getin & get out

---
 src/CuTNetLib/cuMisc.h | 74 ++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 56 insertions(+), 18 deletions(-)

(limited to 'src')

diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h
index 8831622..10418e5 100644
--- a/src/CuTNetLib/cuMisc.h
+++ b/src/CuTNetLib/cuMisc.h
@@ -45,7 +45,7 @@ namespace TNet {
       if (NULL == mpInput) Error("mpInput is NULL");
       mOutput.Init(*mpInput);
     }
-    void BackPropagate()
+    void Backpropagate()
     {
       if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
       mErrorOutput.Init(*mpErrorInput);
@@ -232,7 +232,6 @@ namespace TNet {
       return size;
     }
     
-    /// IO Data getters
     const CuMatrix<BaseFloat>& GetInput(int pos=0)
     {
       if (pos>=0 && pos<size)
@@ -317,6 +316,21 @@ namespace TNet {
       rOut<<std::endl;
     }
     
+    const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)
+    {
+      if (pos>=0 && pos<size)
+        return *ErrInputVec[pos];
+      return *ErrInputVec[0];
+    }
+
+    void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)
+    {
+      if (pos==0)
+        mpErrorInput=&rErrorInput;
+      if (pos>=0 && pos<size)
+        ErrInputVec[pos]=&rErrorInput;
+    }
+
     void Propagate()
     {
       if (NULL == mpInput) Error("mpInput is NULL");
@@ -327,21 +341,25 @@ namespace TNet {
         loc+=SectLen[i];
       }
     }
-     
-   protected:
-    
-    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
-    { Error("__func__ Nonsense"); }
 
-    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+    void Backpropagate()
     {
       int loc=0;
+      mErrorOutput.Init(GetErrorInput.Rows(),GetNInput());
       for (int i=0;i<size;++i)
       {
-        Y.CopyCols(SectLen[i], 0, X, loc);
-        loc+=SectLen[i];
+        mErrorOutput.CopyCols(SectLen[i], 0, GetErrorInput(i), loc);
+        loc += SectLen[i];
       }
     }
+     
+   protected:
+    
+    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+    { Error("__func__ Nonsense"); }
+
+    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+    { Error("__func__ Nonsense"); }
     
     int size;
     MatrixPtrVec OutputVec;
@@ -400,28 +418,48 @@ namespace TNet {
       rOut<<std::endl;
     }
     
-    void Backpropagate()
+    const CuMatrix<BaseFloat>& GetInput(int pos=0)
+    {
+      if (pos>=0 && pos<size)
+        return *InputVec[pos];
+      return *InputVec[0];
+    }
+
+    /// Set input vector (bind with the preceding NetworkComponent)
+    void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)
+    {
+      if (pos==0)
+        mpInput=&rInput;
+      if (pos>=0 && pos<size)
+        InputVec[pos]=&rInput;
+    }
+
+    void Propagate()
     {
-      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
       int loc=0;
+      mOutput.Init(GetInput.Rows(),GetNOutput());
       for (int i=0;i<size;++i)
       {
-        ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]);
-        loc+=SectLen[i];
+        mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc);
+        loc += SectLen[i];
       }
     }
-     
-   protected:
     
-    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+    void Backpropagate()
     {
+      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
       int loc=0;
       for (int i=0;i<size;++i)
       {
-        Y.CopyCols(SectLen[i], 0, X, loc);
+        ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]);
         loc+=SectLen[i];
       }
     }
+     
+   protected:
+    
+    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+    { Error("__func__ Nonsense"); }
 
     void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
     { Error("__func__ Nonsense"); }
-- 
cgit v1.2.3-70-g09d2