#ifndef _CUCOMPDISC_H_
#define _CUCOMPDISC_H_


#include <iostream>
#include <vector>

#include "cuComponent.h"
#include "cumatrix.h"
#include "cuNetwork.h"

#include "Matrix.h"
#include "Vector.h"
#include "Error.h"

namespace TNet {
  
  /**
   * \brief A layer of updatable components
   *
   * \ingroup CuNNUpdatable
   * Each component is individually propagated and backpropagated with discrete inputs and outputs.
   *
   * This enables a multipath topological structure within the network, organized by layers.
   */
   
  class CuLumpUpdatable : public CuUpdatableComponent
  {
    public:
      CuLumpUpdatable(size_t nInputs, size_t nOutputs, CuComponent *pPred)
        : CuUpdatableComponent(nInputs, nOutputs, pPred)
      { }
  
      void LearnRate(BaseFloat rate)
      {
        mLearningRate = rate;
        for (int i=0;i<mBlocks.size();++i)
          if ( mBlocks[i]->IsUpdatable() )
          {
            CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
            rComp.LearnRate(rate);
          }
      }

      //virtual void Propagate(); 
      //virtual void Backpropagate(); 

      void Momentum(BaseFloat mmt)
      {
        mMomentum = mmt;
        for (int i=0;i<mBlocks.size();++i)
          if ( mBlocks[i]->IsUpdatable() )
          {
            CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
            rComp.Momentum(mmt);
          }
      }

      void Weightcost(BaseFloat cost)
      {
        mWeightcost = cost;
        for (int i=0;i<mBlocks.size();++i)
          if ( mBlocks[i]->IsUpdatable() )
          {
            CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
            rComp.Weightcost(cost);
          }
      }

      void GradDivFrm(bool div)
      {
        mGradDivFrm = div;
        for (int i=0;i<mBlocks.size();++i)
          if ( mBlocks[i]->IsUpdatable() )
          {
            CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(*mBlocks[i]);
            rComp.GradDivFrm(div);
          }
      }

    protected:
      std::vector< CuComponent* > mBlocks; ///< vector with component, one component is one block
  };
  
  /**
   * \brief A layer of updatable components
   *
   * \ingroup CuNNUpdatable
   * Each component is individually propagated and backpropagated, with the inputs and outputs kept within one matrix to save space.
   *
   */
  
  class CuDiscrete : public CuLumpUpdatable
  {
    public:
    
      /// Location of one layer-level input/output section: the index of the
      /// owning block in mBlocks and the section position inside that block.
      struct posID
      {
        int block, pos;
        posID(int b, int p) : block(b), pos(p) {}
      };

      CuDiscrete(size_t nInputs, size_t nOutputs, CuComponent *pPred); 
      ~CuDiscrete();  
      
      ComponentType GetType() const;
      const char* GetName() const;
      
      void Propagate(); 
      void Backpropagate(); 

      void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);

      void Update();

      void ReadFromStream(std::istream& rIn);
      void WriteToStream(std::ostream& rOut);
      
      /// Number of input sections exposed by this layer.
      int GetInSect() const
      {
        return inID.size();
      }
      
      /// Number of output sections exposed by this layer.
      int GetOutSect() const
      {
        return outID.size();
      }
      
      /// Map a layer-level input position to the block that owns it.
      /// On return, pos holds the position inside that block.
      /// Calls Error() when pos is outside [0, GetInSect()).
      CuComponent* FindInput(int &pos)
      {
        if (pos < 0 || static_cast<size_t>(pos) >= inID.size())
          Error("Position out of bound");
        const posID& id = inID[pos];
        pos = id.pos;
        return mBlocks[id.block];
      }
      
      /// Map a layer-level output position to the block that owns it.
      /// On return, pos holds the position inside that block.
      /// Calls Error() when pos is outside [0, GetOutSect()).
      CuComponent* FindOutput(int &pos)
      {
        if (pos < 0 || static_cast<size_t>(pos) >= outID.size())
        {
          // Report the offending position; never touch mBlocks here, it may be empty.
          std::cerr << "pos " << pos << " out of " << outID.size()
                    << " output sections" << '\n';
          Error("Position out of bound");
        }
        const posID& id = outID[pos];
        pos = id.pos;
        return mBlocks[id.block];
      }
      
      /// IO Data getters
      /// Input at pos: the preceding component's output when preComp is set,
      /// otherwise the matrix bound at position 0 (pos is then ignored).
      const CuMatrix<BaseFloat>& GetInput(int pos=0)
      {
        if (preComp!=NULL)
          return preComp->GetOutput(pos);
        return *mpInput;
      }
      /// Output at pos, delegated to the block owning that output section.
      const CuMatrix<BaseFloat>& GetOutput(int pos=0)
      {
        CuComponent* pComp=FindOutput(pos);
        return pComp->GetOutput(pos);
      }
      /// Error input at pos: the following component's error output when nxtComp
      /// is set, otherwise the matrix bound at position 0 (pos is then ignored).
      const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)
      {
        if (nxtComp!=NULL)
          return nxtComp->GetErrorOutput(pos);
        return *mpErrorInput;
      }
      /// Error output at pos, delegated to the block owning that input section.
      const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0)
      {
        CuComponent* pComp=FindInput(pos);
        return pComp->GetErrorOutput(pos);
      }

      /// Set input vector (bind with the preceding NetworkComponent).
      /// Position 0 is also remembered as this layer's own input.
      void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)
      {
        if (pos==0)
          mpInput=&rInput;
        CuComponent* pComp=FindInput(pos);
        pComp->SetInput(rInput,pos);
      }          
      /// Set error input vector (bind with the following NetworkComponent).
      /// Position 0 is also remembered as this layer's own error input.
      void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)
      {
        if (pos==0)
          mpErrorInput=&rErrorInput;
        CuComponent* pComp=FindOutput(pos);
        pComp->SetErrorInput(rErrorInput,pos);
      }
    private:
      // Blocks live in the inherited CuLumpUpdatable::mBlocks. Do not redeclare
      // mBlocks here: a shadowing vector hides the blocks from the base-class
      // LearnRate/Momentum/Weightcost/GradDivFrm setters.
      std::vector< posID > inID,outID; ///< layer-level input/output section -> (block, position)
  };




  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuDiscrete::
  /// Construct a discrete layer with no blocks; only forwards its arguments
  /// to CuLumpUpdatable (blocks are presumably added by ReadFromStream).
  inline CuDiscrete::CuDiscrete(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuLumpUpdatable(nInputs, nOutputs, pPred)
  {
  }


  /// Destroy the layer, deleting every block held in mBlocks.
  inline
  CuDiscrete::
  ~CuDiscrete()
  { 
    // size_t index: avoids the signed/unsigned comparison against size()
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      delete mBlocks[i];
    }
    mBlocks.clear();
  }

  /// Type tag identifying this component as a discrete layer.
  inline CuComponent::ComponentType CuDiscrete::GetType() const
  {
    const CuComponent::ComponentType kType = CuComponent::DISCRETE;
    return kType;
  }

  /// Textual tag of this component type.
  inline const char* CuDiscrete::GetName() const
  {
    const char* const kName = "<discrete>";
    return kName;
  }

  /**
   * \brief A layer of updatable components wrapped as a single component
   *
   * \ingroup CuNNUpdatable
   * The blocks are stored in the inherited CuLumpUpdatable::mBlocks, so the
   * LearnRate/Momentum/Weightcost/GradDivFrm setters reach them.
   * NOTE(review): unlike CuDiscrete, no per-position input/output routing is
   * declared here; how the blocks share X and Y is defined in the implementation file.
   */
  class CuCompound : public CuLumpUpdatable
  {
    public:

      CuCompound(size_t nInputs, size_t nOutputs, CuComponent *pPred); 
      ~CuCompound();  
      
      ComponentType GetType() const;
      const char* GetName() const;

      void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);

      void Update();

      void ReadFromStream(std::istream& rIn);
      void WriteToStream(std::ostream& rOut);
      
      void PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
      void BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);

      // No mBlocks member here on purpose: redeclaring it would shadow
      // CuLumpUpdatable::mBlocks and hide the blocks from the base-class setters.
  };
  
  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuCompound::
  /// Construct a compound layer with no blocks; only forwards its arguments
  /// to CuLumpUpdatable.
  inline CuCompound::CuCompound(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuLumpUpdatable(nInputs, nOutputs, pPred)
  {
  }


  /// Destroy the layer, deleting every block held in mBlocks.
  inline
  CuCompound::
  ~CuCompound()
  {
    // size_t index: avoids the signed/unsigned comparison against size()
    for (size_t i = 0; i < mBlocks.size(); ++i) {
      delete mBlocks[i];
    }
    mBlocks.clear(); 
  }

  /// Type tag identifying this component as a compound layer.
  inline CuComponent::ComponentType CuCompound::GetType() const
  {
    const CuComponent::ComponentType kType = CuComponent::COMPOUND;
    return kType;
  }

  /// Textual tag of this component type.
  inline const char* CuCompound::GetName() const
  {
    const char* const kName = "<compound>";
    return kName;
  }

} //namespace



#endif