#include "cuCompDisc.h"
#include "cuNetwork.h"
#include "Error.h"

#include <stdio.h>

namespace TNet {

  //////////////////////////////////////////////////////////////////////
  // CuDiscrete
  //
  // A container of independent sub-blocks (mBlocks). Each external input
  // section i is routed to block inID[i].block at that block's input
  // section inID[i].pos; error inputs are routed the same way through
  // outID. The routing tables are rebuilt in ReadFromStream().
  //////////////////////////////////////////////////////////////////////

  /// Forward pass: hand every external input section to its owning
  /// block, run each block's own Propagate(), then refresh mOutput
  /// from GetOutput().
  void
  CuDiscrete::
  Propagate()
  {
    for (size_t i = 0; i < inID.size(); ++i) {
      mBlocks[inID[i].block]->SetInput(GetInput(i), inID[i].pos);
    }

    for (size_t i = 0; i < mBlocks.size(); ++i) {
      mBlocks[i]->Propagate();
    }

    // NOTE(review): whether Init() copies or aliases the matrix returned
    // by GetOutput() depends on CuMatrix::Init -- confirm.
    mOutput.Init(GetOutput());
  }


  /// Not supported: CuDiscrete works on routed sections via Propagate().
  void
  CuDiscrete::
  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }


  /// Backward pass: hand every external error-input section to its
  /// owning block, run each block's Backpropagate(), then refresh
  /// mErrorOutput from GetErrorOutput().
  void
  CuDiscrete::
  Backpropagate()
  {
    for (size_t i = 0; i < outID.size(); ++i) {
      mBlocks[outID[i].block]->SetErrorInput(GetErrorInput(i), outID[i].pos);
    }

    for (size_t i = 0; i < mBlocks.size(); ++i) {
      mBlocks[i]->Backpropagate();
    }

    mErrorOutput.Init(GetErrorOutput());
  }


  /// Not supported: CuDiscrete works on routed sections via Backpropagate().
  void
  CuDiscrete::
  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }


  /// Call Update() on every block that reports IsUpdatable().
  /// The reference dynamic_cast throws std::bad_cast if a block claims
  /// to be updatable but is not a CuUpdatableComponent.
  void
  CuDiscrete::
  Update()
  {
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      if (mBlocks[i]->IsUpdatable()) {
        CuUpdatableComponent& rComp =
          dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
        rComp.Update();
      }
    }
  }


  /// Replace all blocks with components read from rIn.
  ///
  /// Existing blocks are deleted and the routing tables cleared. Components
  /// are then read until CuNetwork::ComponentReader returns NULL
  /// (presumably on the "<endblock>" tag written by WriteToStream -- confirm).
  /// For each block b, one inID entry (b, j) is appended per input section
  /// and one outID entry (b, j) per output section, in read order.
  void
  CuDiscrete::
  ReadFromStream(std::istream& rIn)
  {
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      delete mBlocks[i];
    }
    mBlocks.clear();
    inID.clear();
    outID.clear();

    CuComponent* comp;
    int blockIdx = 0;  // index of the block being appended to mBlocks
    while (NULL != (comp = CuNetwork::ComponentReader(rIn, NULL))) {
      mBlocks.push_back(comp);
      for (int j = 0; j < comp->GetInSect(); ++j) {
        inID.push_back(posID(blockIdx, j));
      }
      for (int j = 0; j < comp->GetOutSect(); ++j) {
        outID.push_back(posID(blockIdx, j));
      }
      ++blockIdx;
    }
  }


  /// Dump every block with CuNetwork::ComponentDumper, then the
  /// "<endblock>" terminator.
  void
  CuDiscrete::
  WriteToStream(std::ostream& rOut)
  {
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      CuNetwork::ComponentDumper(rOut, *mBlocks[i]);
    }
    rOut << "<endblock>\n";
  }


  //////////////////////////////////////////////////////////////////////
  // CuCompound
  //
  // A container of blocks that each own a consecutive slice of the
  // input/output matrices: block i consumes GetNInputs() input units
  // and produces GetNOutputs() output units, laid out block after block.
  //////////////////////////////////////////////////////////////////////

  /// Forward pass over consecutive slices of X and Y.
  /// NOTE(review): assumes Init(M, offset, width) makes In/Out refer to
  /// `width` units of M starting at `offset` -- confirm in CuMatrix.
  void
  CuCompound::
  PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    int iLoc = 0, oLoc = 0;
    CuMatrix<BaseFloat> In;
    CuMatrix<BaseFloat> Out;
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      In.Init(X, iLoc, mBlocks[i]->GetNInputs());
      Out.Init(Y, oLoc, mBlocks[i]->GetNOutputs());
      mBlocks[i]->PropagateF(In, Out);
      iLoc += mBlocks[i]->GetNInputs();
      oLoc += mBlocks[i]->GetNOutputs();
    }
  }


  /// Backward pass over consecutive slices.
  ///
  /// X is sliced by each block's output width and Y by its input width,
  /// the reverse of PropagateF.
  ///
  /// The former "BP+"/"BP-" printf debug trace has been removed: it wrote
  /// to stdout on every backward pass.
  void
  CuCompound::
  BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    int iLoc = 0, oLoc = 0;
    CuMatrix<BaseFloat> In;
    CuMatrix<BaseFloat> Out;
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      In.Init(X, iLoc, mBlocks[i]->GetNOutputs());
      Out.Init(Y, oLoc, mBlocks[i]->GetNInputs());
      mBlocks[i]->BackpropagateF(In, Out);
      iLoc += mBlocks[i]->GetNOutputs();
      oLoc += mBlocks[i]->GetNInputs();
    }
  }


  /// Not supported: CuCompound works on slices via PropagateF().
  void
  CuCompound::
  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }


  /// Not supported: CuCompound works on slices via BackpropagateF().
  void
  CuCompound::
  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }


  /// Call Update() on every block that reports IsUpdatable().
  /// The reference dynamic_cast throws std::bad_cast if a block claims
  /// to be updatable but is not a CuUpdatableComponent.
  void
  CuCompound::
  Update()
  {
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      if (mBlocks[i]->IsUpdatable()) {
        CuUpdatableComponent& rComp =
          dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
        rComp.Update();
      }
    }
  }


  /// Replace all blocks with components read from rIn.
  /// Existing blocks are deleted first. Reading stops when
  /// CuNetwork::ComponentReader returns NULL.
  void
  CuCompound::
  ReadFromStream(std::istream& rIn)
  {
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      delete mBlocks[i];
    }
    mBlocks.clear();

    CuComponent* comp;
    while (NULL != (comp = CuNetwork::ComponentReader(rIn, NULL))) {
      mBlocks.push_back(comp);
    }
  }


  /// Dump every block with CuNetwork::ComponentDumper, then the
  /// "<endblock>" terminator.
  void
  CuCompound::
  WriteToStream(std::ostream& rOut)
  {
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      CuNetwork::ComponentDumper(rOut, *mBlocks[i]);
    }
    rOut << "<endblock>\n";
  }

} //namespace