From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- src/CuTNetLib/cuCompDisc.cc | 178 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 src/CuTNetLib/cuCompDisc.cc (limited to 'src/CuTNetLib/cuCompDisc.cc') diff --git a/src/CuTNetLib/cuCompDisc.cc b/src/CuTNetLib/cuCompDisc.cc new file mode 100644 index 0000000..2336a86 --- /dev/null +++ b/src/CuTNetLib/cuCompDisc.cc @@ -0,0 +1,178 @@ + + +#include "cuCompDisc.h" +#include "cuNetwork.h" + +#include "Error.h" + + +namespace TNet +{ + + void + CuDiscrete:: + Propagate() + { + for (int i=0;iSetInput(GetInput(i),inID[i].pos); + for (int i=0; iPropagate(); + } + + void + CuDiscrete:: + PropagateFnc(const CuMatrix& X, CuMatrix& Y) + { + Error("Not applicable"); + } + + void + CuDiscrete:: + Backpropagate() + { + for (int i=0;iSetErrorInput(GetOutput(i),outID[i].pos); + for(int i=0; iBackpropagate(); + } + + void + CuDiscrete:: + BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) + { + Error("Not applicable"); + } + + void + CuDiscrete:: + Update() + { + for(int i=0; iIsUpdatable() ) + { + CuUpdatableComponent& rComp = dynamic_cast(*mBlocks[i]); + rComp.Update(); + } + } + + + void + CuDiscrete:: + ReadFromStream(std::istream& rIn) + { + int i; + for(i=0; iGetInSect());++j) + inID.push_back(posID(i,j)); + for (int j=0;j<(comp->GetOutSect());++j) + outID.push_back(posID(i,j)); + ++i; + } + } + + + void + CuDiscrete:: + WriteToStream(std::ostream& rOut) + { + for(int i=0; i\n"; + } + + void + CuCompound:: + PropagateF(CuMatrix& X, CuMatrix& Y) + { + int iLoc=0,oLoc=0; + CuMatrix In; + CuMatrix Out; + for(int i=0; iGetNInputs()); + Out.Init(Y,oLoc,mBlocks[i]->GetNOutputs()); + mBlocks[i]->PropagateF(In,Out); + iLoc+=mBlocks[i]->GetNInputs(); + oLoc+=mBlocks[i]->GetNOutputs(); + } + } + + void + CuCompound:: + BackpropagateF(CuMatrix& X, CuMatrix& Y) + { + int iLoc=0,oLoc=0; + CuMatrix In; + CuMatrix Out; + for(int i=0; iGetNOutputs()); + Out.Init(Y,oLoc,mBlocks[i]->GetNInputs()); + mBlocks[i]->BackpropagateF(In,Out); + iLoc+=mBlocks[i]->GetNOutputs(); + oLoc+=mBlocks[i]->GetNInputs(); + } + } + + void + CuCompound:: + PropagateFnc(const CuMatrix& X, CuMatrix& Y) + { + Error("Not applicable"); + } + + void + CuCompound:: + BackpropagateFnc(const CuMatrix& X, CuMatrix& Y) + { + Error("Not applicable"); + } + + void + CuCompound:: + Update() + { + for(int i=0; iIsUpdatable() ) + { + CuUpdatableComponent& rComp = dynamic_cast(*mBlocks[i]); + rComp.Update(); + } + } + + + void + CuCompound:: + ReadFromStream(std::istream& rIn) + { + for(int i=0; i\n"; + } + +} //namespace + -- cgit v1.2.3-70-g09d2