From 0653afe62c109f333677f9fd90d19d4727e7cca5 Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 11:15:21 +0800 Subject: Supporting const rev. --- src/CuTNetLib/cuComponent.h | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) (limited to 'src/CuTNetLib/cuComponent.h') diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h index 6cc8462..fc9666c 100644 --- a/src/CuTNetLib/cuComponent.h +++ b/src/CuTNetLib/cuComponent.h @@ -85,6 +85,7 @@ namespace TNet { } ComponentType; typedef std::vector< CuMatrix* > MatrixPtrVec; + typedef std::vector< const CuMatrix* > ConstMatrixPtrVec; ////////////////////////////////////////////////////////////// // Constructor & Destructor @@ -118,20 +119,20 @@ namespace TNet { void SetNext(CuComponent* pNxt); /// Return the number of different inputs for complex component - int GetInSect(); + int GetInSect() const; /// Return the number of different outputs for complex component - int GetOutSect(); + int GetOutSect() const; /// IO Data getters - CuMatrix& GetInput(int pos=0); - CuMatrix& GetOutput(int pos=0); - CuMatrix& GetErrorInput(int pos=0); - CuMatrix& GetErrorOutput(int pos=0); + const CuMatrix& GetInput(int pos=0); + const CuMatrix& GetOutput(int pos=0); + const CuMatrix& GetErrorInput(int pos=0); + const CuMatrix& GetErrorOutput(int pos=0); /// Set input vector (bind with the preceding NetworkComponent) - void SetInput(CuMatrix& rInput,int pos=0); + void SetInput(const CuMatrix& rInput,int pos=0); /// Set error input vector (bind with the following NetworkComponent) - void SetErrorInput(CuMatrix& rErrorInput,int pos=0); + void SetErrorInput(const CuMatrix& rErrorInput,int pos=0); /// Perform forward pass propagateion Input->Output, /// wrapper for the PropagateFnc method @@ -146,9 +147,9 @@ namespace TNet { virtual void WriteToStream(std::ostream& rOut) { } /// Public wrapper for PropagateFnc - void PropagateF(CuMatrix& X, CuMatrix& Y); + void PropagateF(const CuMatrix& X, CuMatrix& Y); /// Public wrapper for BackpropagateFnc - void BackpropagateF(CuMatrix& X, CuMatrix& Y); + void BackpropagateF(const CuMatrix& X, CuMatrix& Y); /////////////////////////////////////////////////////////////// @@ -171,8 +172,8 @@ namespace TNet { size_t mNInputs; ///< Size of input vectors size_t mNOutputs; ///< Size of output vectors - CuMatrix* mpInput; ///< inputs are NOT OWNED by component - CuMatrix* mpErrorInput;///< inputs are NOT OWNED by component + const CuMatrix* mpInput; ///< inputs are NOT OWNED by component + const CuMatrix* mpErrorInput;///< inputs are NOT OWNED by component CuMatrix mOutput; ///< outputs are OWNED by component CuMatrix mErrorOutput; ///< outputs are OWNED by component @@ -316,7 +317,7 @@ namespace TNet { inline void CuComponent:: - SetInput(CuMatrix& rInput,int pos) + SetInput(const CuMatrix& rInput,int pos) { mpInput = &rInput; } @@ -324,12 +325,12 @@ namespace TNet { inline void CuComponent:: - SetErrorInput(CuMatrix& rErrorInput,int pos) + SetErrorInput(const CuMatrix& rErrorInput,int pos) { mpErrorInput = &rErrorInput; } - inline CuMatrix& + inline const CuMatrix& CuComponent:: GetInput(int pos) { @@ -337,14 +338,14 @@ namespace TNet { return *mpInput; } - inline CuMatrix& + inline const CuMatrix& CuComponent:: GetOutput(int pos) { return mOutput; } - inline CuMatrix& + inline const CuMatrix& CuComponent:: GetErrorInput(int pos) { @@ -352,7 +353,7 @@ namespace TNet { return *mpErrorInput; } - inline CuMatrix& + inline const CuMatrix& CuComponent:: GetErrorOutput(int pos) { @@ -375,14 +376,14 @@ namespace TNet { inline int CuComponent:: - GetInSect() + GetInSect() const { return 1; } inline int CuComponent:: - GetOutSect() + GetOutSect() const { return 1; } @@ -403,13 +404,13 @@ namespace TNet { inline void CuComponent:: - PropagateF(CuMatrix& X, CuMatrix& Y) + PropagateF(const CuMatrix& X, CuMatrix& Y) { PropagateFnc(X,Y); } inline void CuComponent:: - BackpropagateF(CuMatrix& X, CuMatrix& Y) + BackpropagateF(const CuMatrix& X, CuMatrix& Y) { BackpropagateFnc(X,Y); } -- cgit v1.2.3-70-g09d2