summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuComponent.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/CuTNetLib/cuComponent.h')
-rw-r--r--src/CuTNetLib/cuComponent.h45
1 files changed, 23 insertions, 22 deletions
diff --git a/src/CuTNetLib/cuComponent.h b/src/CuTNetLib/cuComponent.h
index 6cc8462..fc9666c 100644
--- a/src/CuTNetLib/cuComponent.h
+++ b/src/CuTNetLib/cuComponent.h
@@ -85,6 +85,7 @@ namespace TNet {
} ComponentType;
typedef std::vector< CuMatrix<BaseFloat>* > MatrixPtrVec;
+ typedef std::vector< const CuMatrix<BaseFloat>* > ConstMatrixPtrVec;
//////////////////////////////////////////////////////////////
// Constructor & Destructor
@@ -118,20 +119,20 @@ namespace TNet {
void SetNext(CuComponent* pNxt);
/// Return the number of different inputs for complex component
- int GetInSect();
+ int GetInSect() const;
/// Return the number of different outputs for complex component
- int GetOutSect();
+ int GetOutSect() const;
/// IO Data getters
- CuMatrix<BaseFloat>& GetInput(int pos=0);
- CuMatrix<BaseFloat>& GetOutput(int pos=0);
- CuMatrix<BaseFloat>& GetErrorInput(int pos=0);
- CuMatrix<BaseFloat>& GetErrorOutput(int pos=0);
+ const CuMatrix<BaseFloat>& GetInput(int pos=0);
+ const CuMatrix<BaseFloat>& GetOutput(int pos=0);
+ const CuMatrix<BaseFloat>& GetErrorInput(int pos=0);
+ const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0);
/// Set input vector (bind with the preceding NetworkComponent)
- void SetInput(CuMatrix<BaseFloat>& rInput,int pos=0);
+ void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0);
/// Set error input vector (bind with the following NetworkComponent)
- void SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos=0);
+ void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0);
/// Perform forward pass propagateion Input->Output,
/// wrapper for the PropagateFnc method
@@ -146,9 +147,9 @@ namespace TNet {
virtual void WriteToStream(std::ostream& rOut) { }
/// Public wrapper for PropagateFnc
- void PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
/// Public wrapper for BackpropagateFnc
- void BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
///////////////////////////////////////////////////////////////
@@ -171,8 +172,8 @@ namespace TNet {
size_t mNInputs; ///< Size of input vectors
size_t mNOutputs; ///< Size of output vectors
- CuMatrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component
- CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component
+ const CuMatrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component
+ const CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component
CuMatrix<BaseFloat> mOutput; ///< outputs are OWNED by component
CuMatrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component
@@ -316,7 +317,7 @@ namespace TNet {
inline void
CuComponent::
- SetInput(CuMatrix<BaseFloat>& rInput,int pos)
+ SetInput(const CuMatrix<BaseFloat>& rInput,int pos)
{
mpInput = &rInput;
}
@@ -324,12 +325,12 @@ namespace TNet {
inline void
CuComponent::
- SetErrorInput(CuMatrix<BaseFloat>& rErrorInput,int pos)
+ SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos)
{
mpErrorInput = &rErrorInput;
}
- inline CuMatrix<BaseFloat>&
+ inline const CuMatrix<BaseFloat>&
CuComponent::
GetInput(int pos)
{
@@ -337,14 +338,14 @@ namespace TNet {
return *mpInput;
}
- inline CuMatrix<BaseFloat>&
+ inline const CuMatrix<BaseFloat>&
CuComponent::
GetOutput(int pos)
{
return mOutput;
}
- inline CuMatrix<BaseFloat>&
+ inline const CuMatrix<BaseFloat>&
CuComponent::
GetErrorInput(int pos)
{
@@ -352,7 +353,7 @@ namespace TNet {
return *mpErrorInput;
}
- inline CuMatrix<BaseFloat>&
+ inline const CuMatrix<BaseFloat>&
CuComponent::
GetErrorOutput(int pos)
{
@@ -375,14 +376,14 @@ namespace TNet {
inline int
CuComponent::
- GetInSect()
+ GetInSect() const
{
return 1;
}
inline int
CuComponent::
- GetOutSect()
+ GetOutSect() const
{
return 1;
}
@@ -403,13 +404,13 @@ namespace TNet {
inline void
CuComponent::
- PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
{
PropagateFnc(X,Y);
}
inline void
CuComponent::
- BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
{
BackpropagateFnc(X,Y);
}