diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/CuTNetLib/cuMisc.h | 74 | 
1 files changed, 56 insertions, 18 deletions
| diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h index 8831622..10418e5 100644 --- a/src/CuTNetLib/cuMisc.h +++ b/src/CuTNetLib/cuMisc.h @@ -45,7 +45,7 @@ namespace TNet {        if (NULL == mpInput) Error("mpInput is NULL");        mOutput.Init(*mpInput);      } -    void BackPropagate() +    void Backpropagate()      {        if (NULL == mpErrorInput) Error("mpErrorInput is NULL");        mErrorOutput.Init(*mpErrorInput); @@ -232,7 +232,6 @@ namespace TNet {        return size;      } -    /// IO Data getters      const CuMatrix<BaseFloat>& GetInput(int pos=0)      {        if (pos>=0 && pos<size) @@ -317,6 +316,21 @@ namespace TNet {        rOut<<std::endl;      } +    const CuMatrix<BaseFloat>& GetErrorInput(int pos=0) +    { +      if (pos>=0 && pos<size) +        return *ErrInputVec[pos]; +      return *ErrInputVec[0]; +    } + +    void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0) +    { +      if (pos==0) +        mpErrorInput=&rErrorInput; +      if (pos>=0 && pos<size) +        ErrInputVec[pos]=&rErrorInput; +    } +      void Propagate()      {        if (NULL == mpInput) Error("mpInput is NULL"); @@ -327,21 +341,25 @@ namespace TNet {          loc+=SectLen[i];        }      } -      -   protected: -     -    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) -    { Error("__func__ Nonsense"); } -    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) +    void Backpropagate()      {        int loc=0; +      mErrorOutput.Init(GetErrorInput.Rows(),GetNInput());        for (int i=0;i<size;++i)        { -        Y.CopyCols(SectLen[i], 0, X, loc); -        loc+=SectLen[i]; +        mErrorOutput.CopyCols(SectLen[i], 0, GetErrorInput(i), loc); +        loc += SectLen[i];        }      } +      +   protected: +     +    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) +    { Error("__func__ Nonsense"); } + +    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) +    { Error("__func__ Nonsense"); }      int size;      MatrixPtrVec OutputVec; @@ -400,28 +418,48 @@ namespace TNet {        rOut<<std::endl;      } -    void Backpropagate() +    const CuMatrix<BaseFloat>& GetInput(int pos=0) +    { +      if (pos>=0 && pos<size) +        return *InputVec[pos]; +      return *InputVec[0]; +    } + +    /// Set input vector (bind with the preceding NetworkComponent) +    void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0) +    { +      if (pos==0) +        mpInput=&rInput; +      if (pos>=0 && pos<size) +        InputVec[pos]=&rInput; +    } + +    void Propagate()      { -      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");        int loc=0; +      mOutput.Init(GetInput.Rows(),GetNOutput());        for (int i=0;i<size;++i)        { -        ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]); -        loc+=SectLen[i]; +        mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc); +        loc += SectLen[i];        }      } -      -   protected: -    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) +    void Backpropagate()      { +      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");        int loc=0;        for (int i=0;i<size;++i)        { -        Y.CopyCols(SectLen[i], 0, X, loc); +        ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]);          loc+=SectLen[i];        }      } +      +   protected: +     +    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) +    { Error("__func__ Nonsense"); }      void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)      { Error("__func__ Nonsense"); } | 
