diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuCompDisc.cc | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/CuTNetLib/cuCompDisc.cc')
-rw-r--r-- | src/CuTNetLib/cuCompDisc.cc | 178 |
1 files changed, 178 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuCompDisc.cc b/src/CuTNetLib/cuCompDisc.cc new file mode 100644 index 0000000..2336a86 --- /dev/null +++ b/src/CuTNetLib/cuCompDisc.cc @@ -0,0 +1,178 @@ + + +#include "cuCompDisc.h" +#include "cuNetwork.h" + +#include "Error.h" + + +namespace TNet +{ + + void + CuDiscrete:: + Propagate() + { + for (int i=0;i<inID.size(); i++) + mBlocks[inID[i].block]->SetInput(GetInput(i),inID[i].pos); + for (int i=0; i<mBlocks.size(); i++) + mBlocks[i]->Propagate(); + } + + void + CuDiscrete:: + PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + Error("Not applicable"); + } + + void + CuDiscrete:: + Backpropagate() + { + for (int i=0;i<outID.size(); i++) + mBlocks[outID[i].block]->SetErrorInput(GetOutput(i),outID[i].pos); + for(int i=0; i<mBlocks.size(); i++) + mBlocks[i]->Backpropagate(); + } + + void + CuDiscrete:: + BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + Error("Not applicable"); + } + + void + CuDiscrete:: + Update() + { + for(int i=0; i<mBlocks.size(); i++) + if ( mBlocks[i]->IsUpdatable() ) + { + CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]); + rComp.Update(); + } + } + + + void + CuDiscrete:: + ReadFromStream(std::istream& rIn) + { + int i; + for(i=0; i<mBlocks.size(); i++) { + delete mBlocks[i]; + } + mBlocks.clear(); + inID.clear(); + outID.clear(); + CuComponent* comp; + i=0; + while ( NULL != (comp=CuNetwork::ComponentReader(rIn,NULL)) ) + { + mBlocks.push_back(comp); + for (int j=0;j<(comp->GetInSect());++j) + inID.push_back(posID(i,j)); + for (int j=0;j<(comp->GetOutSect());++j) + outID.push_back(posID(i,j)); + ++i; + } + } + + + void + CuDiscrete:: + WriteToStream(std::ostream& rOut) + { + for(int i=0; i<mBlocks.size(); i++) + CuNetwork::ComponentDumper(rOut,*mBlocks[i]); + rOut << "<endblock>\n"; + } + + void + CuCompound:: + PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + int iLoc=0,oLoc=0; + CuMatrix<BaseFloat> In; + CuMatrix<BaseFloat> Out; + for(int i=0; i<mBlocks.size(); i++) + { + In.Init(X,iLoc,mBlocks[i]->GetNInputs()); + Out.Init(Y,oLoc,mBlocks[i]->GetNOutputs()); + mBlocks[i]->PropagateF(In,Out); + iLoc+=mBlocks[i]->GetNInputs(); + oLoc+=mBlocks[i]->GetNOutputs(); + } + } + + void + CuCompound:: + BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + int iLoc=0,oLoc=0; + CuMatrix<BaseFloat> In; + CuMatrix<BaseFloat> Out; + for(int i=0; i<mBlocks.size(); i++) + { + In.Init(X,iLoc,mBlocks[i]->GetNOutputs()); + Out.Init(Y,oLoc,mBlocks[i]->GetNInputs()); + mBlocks[i]->BackpropagateF(In,Out); + iLoc+=mBlocks[i]->GetNOutputs(); + oLoc+=mBlocks[i]->GetNInputs(); + } + } + + void + CuCompound:: + PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + Error("Not applicable"); + } + + void + CuCompound:: + BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) + { + Error("Not applicable"); + } + + void + CuCompound:: + Update() + { + for(int i=0; i<mBlocks.size(); i++) + if ( mBlocks[i]->IsUpdatable() ) + { + CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]); + rComp.Update(); + } + } + + + void + CuCompound:: + ReadFromStream(std::istream& rIn) + { + for(int i=0; i<mBlocks.size(); i++) { + delete mBlocks[i]; + } + mBlocks.clear(); + CuComponent* comp; + while ( NULL != (comp=CuNetwork::ComponentReader(rIn,NULL)) ) + mBlocks.push_back(comp); + } + + + void + CuCompound:: + WriteToStream(std::ostream& rOut) + { + for(int i=0; i<mBlocks.size(); i++) + CuNetwork::ComponentDumper(rOut,*mBlocks[i]); + rOut << "<endblock>\n"; + } + +} //namespace + |