#include "cuCompDisc.h"
#include "cuNetwork.h"

#include "Error.h"

#include <stdio.h>

namespace TNet
{
  
  // Propagates through the sub-blocks.
  //
  // Input section i of this component is handed to sub-block inID[i].block
  // at that block's local input position inID[i].pos (the table is filled
  // by ReadFromStream, one entry per sub-block input section, in block
  // order). Every sub-block is then propagated in storage order, and
  // mOutput is re-initialized from GetOutput().
  void 
  CuDiscrete::
  Propagate()
  {
    // Keep an int index because GetInput() was always called with an int;
    // cast the bound instead of comparing signed against unsigned size().
    for (int i = 0; i < static_cast<int>(inID.size()); ++i)
      mBlocks[inID[i].block]->SetInput(GetInput(i), inID[i].pos);
    for (size_t b = 0; b < mBlocks.size(); ++b)
      mBlocks[b]->Propagate();
    mOutput.Init(GetOutput());
  }

  void 
  CuDiscrete::
  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }

  // Backpropagates through the sub-blocks.
  //
  // Error-input section i of this component is handed to sub-block
  // outID[i].block at that block's local position outID[i].pos (the table is
  // filled by ReadFromStream, one entry per sub-block output section, in
  // block order). Every sub-block is then backpropagated in storage order,
  // and mErrorOutput is re-initialized from GetErrorOutput().
  void 
  CuDiscrete::
  Backpropagate()
  {
    // Keep an int index because GetErrorInput() was always called with an
    // int; cast the bound instead of comparing signed against unsigned size().
    for (int i = 0; i < static_cast<int>(outID.size()); ++i)
      mBlocks[outID[i].block]->SetErrorInput(GetErrorInput(i), outID[i].pos);
    for (size_t b = 0; b < mBlocks.size(); ++b)
      mBlocks[b]->Backpropagate();
    mErrorOutput.Init(GetErrorOutput());
  }
  
  void 
  CuDiscrete::
  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }

  // Calls Update() on every sub-block whose IsUpdatable() returns true.
  // Sub-blocks are visited in storage order.
  void 
  CuDiscrete::
  Update() 
  {
    for (size_t b = 0; b < mBlocks.size(); ++b) {
      if (!mBlocks[b]->IsUpdatable())
        continue;
      // Reference dynamic_cast: throws std::bad_cast if a block reports
      // IsUpdatable() but is not actually a CuUpdatableComponent.
      CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[b]);
      rComp.Update();
    }
  }


  // Replaces the sub-blocks with components read from rIn.
  //
  // The existing sub-blocks are deleted, and the inID/outID routing tables
  // are cleared. Components are then read with CuNetwork::ComponentReader
  // until it returns NULL. For each component read, the loop appends:
  //   - one inID entry  posID(blockIndex, j) per input section  (GetInSect())
  //   - one outID entry posID(blockIndex, j) per output section (GetOutSect())
  // Propagate/Backpropagate use these entries to route sections to blocks.
  void
  CuDiscrete::
  ReadFromStream(std::istream& rIn)
  {
    for (size_t b = 0; b < mBlocks.size(); ++b)
      delete mBlocks[b];
    mBlocks.clear();
    inID.clear();
    outID.clear();

    CuComponent* comp;
    while ( NULL != (comp = CuNetwork::ComponentReader(rIn, NULL)) )
    {
      // mBlocks was cleared above, so its current size is exactly the index
      // the new component will occupy. This replaces a hand-kept counter.
      const int blockIdx = static_cast<int>(mBlocks.size());
      mBlocks.push_back(comp);
      for (int j = 0; j < comp->GetInSect(); ++j)
        inID.push_back(posID(blockIdx, j));
      for (int j = 0; j < comp->GetOutSect(); ++j)
        outID.push_back(posID(blockIdx, j));
    }
  }

   
  // Writes every sub-block to rOut via CuNetwork::ComponentDumper, in
  // storage order, followed by an "<endblock>" terminator line.
  void
  CuDiscrete::
  WriteToStream(std::ostream& rOut)
  {
    for (size_t b = 0; b < mBlocks.size(); ++b)
      CuNetwork::ComponentDumper(rOut, *mBlocks[b]);
    rOut << "<endblock>\n";
  }

  // Propagates X through the sub-blocks laid side by side, writing into Y.
  //
  // Sub-block b receives the slice of X at offset iLoc with width
  // GetNInputs(), and writes the slice of Y at offset oLoc with width
  // GetNOutputs(). Both offsets then advance by those widths.
  // NOTE(review): assumes CuMatrix::Init(src, offset, width) makes In/Out
  // refer to a sub-range of src rather than copying it (Out must alias Y
  // for results to land in Y) — confirm against CuMatrix.
  void 
  CuCompound::
  PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    int iLoc = 0;  // start of the current block's slice within X
    int oLoc = 0;  // start of the current block's slice within Y
    CuMatrix<BaseFloat> In;
    CuMatrix<BaseFloat> Out;
    for (size_t b = 0; b < mBlocks.size(); ++b)
    {
      In.Init(X, iLoc, mBlocks[b]->GetNInputs());
      Out.Init(Y, oLoc, mBlocks[b]->GetNOutputs());
      mBlocks[b]->PropagateF(In, Out);
      iLoc += mBlocks[b]->GetNInputs();
      oLoc += mBlocks[b]->GetNOutputs();
    }
  }
  
  // Backpropagates X through the sub-blocks laid side by side, writing into Y.
  //
  // The roles of the widths are swapped relative to PropagateF. Sub-block b
  // reads the slice of X at offset iLoc with width GetNOutputs(), and writes
  // the slice of Y at offset oLoc with width GetNInputs(). Both offsets then
  // advance by those widths.
  // NOTE(review): assumes CuMatrix::Init(src, offset, width) makes In/Out
  // refer to a sub-range of src rather than copying it (Out must alias Y
  // for results to land in Y) — confirm against CuMatrix.
  void 
  CuCompound::
  BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    int iLoc = 0;  // start of the current block's slice within X
    int oLoc = 0;  // start of the current block's slice within Y
    CuMatrix<BaseFloat> In;
    CuMatrix<BaseFloat> Out;
    for (size_t b = 0; b < mBlocks.size(); ++b)
    {
      In.Init(X, iLoc, mBlocks[b]->GetNOutputs());
      Out.Init(Y, oLoc, mBlocks[b]->GetNInputs());
      mBlocks[b]->BackpropagateF(In, Out);
      iLoc += mBlocks[b]->GetNOutputs();
      oLoc += mBlocks[b]->GetNInputs();
    }
  }
  
  void 
  CuCompound::
  PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }
  
  void 
  CuCompound::
  BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
    Error("Not applicable");
  }

  // Calls Update() on every sub-block whose IsUpdatable() returns true.
  // Sub-blocks are visited in storage order.
  void 
  CuCompound::
  Update() 
  {
    for (size_t b = 0; b < mBlocks.size(); ++b) {
      if (!mBlocks[b]->IsUpdatable())
        continue;
      // Reference dynamic_cast: throws std::bad_cast if a block reports
      // IsUpdatable() but is not actually a CuUpdatableComponent.
      CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[b]);
      rComp.Update();
    }
  }


  // Replaces the sub-blocks with components read from rIn.
  //
  // The existing sub-blocks are deleted. Components are then read with
  // CuNetwork::ComponentReader and appended in order until it returns NULL.
  void
  CuCompound::
  ReadFromStream(std::istream& rIn)
  {
    for (size_t b = 0; b < mBlocks.size(); ++b)
      delete mBlocks[b];
    mBlocks.clear();

    CuComponent* comp;
    while ( NULL != (comp = CuNetwork::ComponentReader(rIn, NULL)) )
      mBlocks.push_back(comp);
  }

   
  // Writes every sub-block to rOut via CuNetwork::ComponentDumper, in
  // storage order, followed by an "<endblock>" terminator line.
  void
  CuCompound::
  WriteToStream(std::ostream& rOut)
  {
    for (size_t b = 0; b < mBlocks.size(); ++b)
      CuNetwork::ComponentDumper(rOut, *mBlocks[b]);
    rOut << "<endblock>\n";
  }
 
} //namespace