diff options
Diffstat (limited to 'src/CuTNetLib/cuComponent.h')
-rw-r--r-- | src/CuTNetLib/cuComponent.h | 45 |
1 files changed, 23 insertions, 22 deletions
diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h index 6cc8462..fc9666c 100644 --- a/src/CuTNetLib/cuComponent.h +++ b/src/CuTNetLib/cuComponent.h @@ -85,6 +85,7 @@ namespace TNet { } ComponentType; typedef std::vector< CuMatrix<BaseFloat>* > MatrixPtrVec; + typedef std::vector< const CuMatrix<BaseFloat>* > ConstMatrixPtrVec; ////////////////////////////////////////////////////////////// // Constructor & Destructor @@ -118,20 +119,20 @@ namespace TNet { void SetNext(CuComponent* pNxt); /// Return the number of different inputs for complex component - int GetInSect(); + int GetInSect() const; /// Return the number of different outputs for complex component - int GetOutSect(); + int GetOutSect() const; /// IO Data getters - CuMatrix<BaseFloat>& GetInput(int pos=0); - CuMatrix<BaseFloat>& GetOutput(int pos=0); - CuMatrix<BaseFloat>& GetErrorInput(int pos=0); - CuMatrix<BaseFloat>& GetErrorOutput(int pos=0); + const CuMatrix<BaseFloat>& GetInput(int pos=0); + const CuMatrix<BaseFloat>& GetOutput(int pos=0); + const CuMatrix<BaseFloat>& GetErrorInput(int pos=0); + const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0); /// Set input vector (bind with the preceding NetworkComponent) - void SetInput(CuMatrix<BaseFloat>& rInput,int pos=0); + void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0); /// Set error input vector (bind with the following NetworkComponent) - void SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos=0); + void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0); /// Perform forward pass propagateion Input->Output, /// wrapper for the PropagateFnc method @@ -146,9 +147,9 @@ namespace TNet { virtual void WriteToStream(std::ostream& rOut) { } /// Public wrapper for PropagateFnc - void PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + void PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); /// Public wrapper for BackpropagateFnc - void BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + void BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); /////////////////////////////////////////////////////////////// @@ -171,8 +172,8 @@ namespace TNet { size_t mNInputs; ///< Size of input vectors size_t mNOutputs; ///< Size of output vectors - CuMatrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component - CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component + const CuMatrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component + const CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component CuMatrix<BaseFloat> mOutput; ///< outputs are OWNED by component CuMatrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component @@ -316,7 +317,7 @@ namespace TNet { inline void CuComponent:: - SetInput(CuMatrix<BaseFloat>& rInput,int pos) + SetInput(const CuMatrix<BaseFloat>& rInput,int pos) { mpInput = &rInput; } @@ -324,12 +325,12 @@ namespace TNet { inline void CuComponent:: - SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos) + SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos) { mpErrorInput = &rErrorInput; } - inline CuMatrix<BaseFloat>& + inline const CuMatrix<BaseFloat>& CuComponent:: GetInput(int pos) { @@ -337,14 +338,14 @@ namespace TNet { return *mpInput; } - inline CuMatrix<BaseFloat>& + inline const CuMatrix<BaseFloat>& CuComponent:: GetOutput(int pos) { return mOutput; } - inline CuMatrix<BaseFloat>& + inline const CuMatrix<BaseFloat>& CuComponent:: GetErrorInput(int pos) { @@ -352,7 +353,7 @@ namespace TNet { return *mpErrorInput; } - inline CuMatrix<BaseFloat>& + inline const CuMatrix<BaseFloat>& CuComponent:: GetErrorOutput(int pos) { @@ -375,14 +376,14 @@ namespace TNet { inline int CuComponent:: - GetInSect() + GetInSect() const { return 1; } inline int CuComponent:: - GetOutSect() + GetOutSect() const { return 1; } @@ -403,13 +404,13 @@ namespace TNet { inline void CuComponent:: - PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) { PropagateFnc(X,Y); } inline void CuComponent:: - BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) { BackpropagateFnc(X,Y); } |