summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuCompDisc.cc
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuCompDisc.cc
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/CuTNetLib/cuCompDisc.cc')
-rw-r--r--src/CuTNetLib/cuCompDisc.cc178
1 files changed, 178 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuCompDisc.cc b/src/CuTNetLib/cuCompDisc.cc
new file mode 100644
index 0000000..2336a86
--- /dev/null
+++ b/src/CuTNetLib/cuCompDisc.cc
@@ -0,0 +1,178 @@
+
+
+#include "cuCompDisc.h"
+#include "cuNetwork.h"
+
+#include "Error.h"
+
+
+namespace TNet
+{
+
+ void
+ CuDiscrete::
+ Propagate()
+ {
+ for (int i=0;i<inID.size(); i++)
+ mBlocks[inID[i].block]->SetInput(GetInput(i),inID[i].pos);
+ for (int i=0; i<mBlocks.size(); i++)
+ mBlocks[i]->Propagate();
+ }
+
+ void
+ CuDiscrete::
+ PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ Error("Not applicable");
+ }
+
+ void
+ CuDiscrete::
+ Backpropagate()
+ {
+ for (int i=0;i<outID.size(); i++)
+ mBlocks[outID[i].block]->SetErrorInput(GetOutput(i),outID[i].pos);
+ for(int i=0; i<mBlocks.size(); i++)
+ mBlocks[i]->Backpropagate();
+ }
+
+ void
+ CuDiscrete::
+ BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ Error("Not applicable");
+ }
+
+ void
+ CuDiscrete::
+ Update()
+ {
+ for(int i=0; i<mBlocks.size(); i++)
+ if ( mBlocks[i]->IsUpdatable() )
+ {
+ CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
+ rComp.Update();
+ }
+ }
+
+
+ void
+ CuDiscrete::
+ ReadFromStream(std::istream& rIn)
+ {
+ int i;
+ for(i=0; i<mBlocks.size(); i++) {
+ delete mBlocks[i];
+ }
+ mBlocks.clear();
+ inID.clear();
+ outID.clear();
+ CuComponent* comp;
+ i=0;
+ while ( NULL != (comp=CuNetwork::ComponentReader(rIn,NULL)) )
+ {
+ mBlocks.push_back(comp);
+ for (int j=0;j<(comp->GetInSect());++j)
+ inID.push_back(posID(i,j));
+ for (int j=0;j<(comp->GetOutSect());++j)
+ outID.push_back(posID(i,j));
+ ++i;
+ }
+ }
+
+
+ void
+ CuDiscrete::
+ WriteToStream(std::ostream& rOut)
+ {
+ for(int i=0; i<mBlocks.size(); i++)
+ CuNetwork::ComponentDumper(rOut,*mBlocks[i]);
+ rOut << "<endblock>\n";
+ }
+
+ void
+ CuCompound::
+ PropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ int iLoc=0,oLoc=0;
+ CuMatrix<BaseFloat> In;
+ CuMatrix<BaseFloat> Out;
+ for(int i=0; i<mBlocks.size(); i++)
+ {
+ In.Init(X,iLoc,mBlocks[i]->GetNInputs());
+ Out.Init(Y,oLoc,mBlocks[i]->GetNOutputs());
+ mBlocks[i]->PropagateF(In,Out);
+ iLoc+=mBlocks[i]->GetNInputs();
+ oLoc+=mBlocks[i]->GetNOutputs();
+ }
+ }
+
+ void
+ CuCompound::
+ BackpropagateF(CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ int iLoc=0,oLoc=0;
+ CuMatrix<BaseFloat> In;
+ CuMatrix<BaseFloat> Out;
+ for(int i=0; i<mBlocks.size(); i++)
+ {
+ In.Init(X,iLoc,mBlocks[i]->GetNOutputs());
+ Out.Init(Y,oLoc,mBlocks[i]->GetNInputs());
+ mBlocks[i]->BackpropagateF(In,Out);
+ iLoc+=mBlocks[i]->GetNOutputs();
+ oLoc+=mBlocks[i]->GetNInputs();
+ }
+ }
+
+ void
+ CuCompound::
+ PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ Error("Not applicable");
+ }
+
+ void
+ CuCompound::
+ BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
+ {
+ Error("Not applicable");
+ }
+
+ void
+ CuCompound::
+ Update()
+ {
+ for(int i=0; i<mBlocks.size(); i++)
+ if ( mBlocks[i]->IsUpdatable() )
+ {
+ CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
+ rComp.Update();
+ }
+ }
+
+
+ void
+ CuCompound::
+ ReadFromStream(std::istream& rIn)
+ {
+ for(int i=0; i<mBlocks.size(); i++) {
+ delete mBlocks[i];
+ }
+ mBlocks.clear();
+ CuComponent* comp;
+ while ( NULL != (comp=CuNetwork::ComponentReader(rIn,NULL)) )
+ mBlocks.push_back(comp);
+ }
+
+
+ void
+ CuCompound::
+ WriteToStream(std::ostream& rOut)
+ {
+ for(int i=0; i<mBlocks.size(); i++)
+ CuNetwork::ComponentDumper(rOut,*mBlocks[i]);
+ rOut << "<endblock>\n";
+ }
+
+} //namespace
+