#ifndef _CUMISC_H_
#define _CUMISC_H_

#include <vector>

#include "cuComponent.h"
#include "cumatrix.h"


#include "Matrix.h"
#include "Vector.h"
#include "Error.h"


namespace TNet {
  /**
   * \brief A pipe for input and errorinput propagation(doesn't incurr copy)
   * 
   * \ingroup CuNNMisc
   */
  class CuPipe : public CuComponent
  {
    public:
    CuPipe(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred)
    { }

    ~CuPipe()
    { }

    ComponentType GetType() const
    { return PIPE; }

    const char* GetName() const
    { return "<pipe>"; }
   
    void ReadFromStream(std::istream& rIn)
    { }

    void WriteToStream(std::ostream& rOut)  
    { }
    
    void Propagate()
    {
      if (NULL == mpInput) Error("mpInput is NULL");
      mOutput.Init(*mpInput);
    }
    void Backpropagate()
    {
      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
      mErrorOutput.Init(*mpErrorInput);
    }
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Y.CopyFrom(X);}

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Y.CopyFrom(X); }

  };
  
  /**
   * \brief A pipe for input propagation(doesn't incurr copy) and set any error to zero
   * 
   * \ingroup CuNNMisc
   * 
   * @todo have to be set to zero on every pass(setup a common zeroed space?!)
   */
  class CuLearnStop : public CuComponent
  {
    public:
    CuLearnStop(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred)
    { }

    ~CuLearnStop()
    { }

    ComponentType GetType() const
    { return LEARNSTOP; }

    const char* GetName() const
    { return "<learnstop>"; }
   
    void ReadFromStream(std::istream& rIn)
    { }

    void WriteToStream(std::ostream& rOut)  
    { }
    
    void Propagate()
    {
      if (NULL == mpInput) Error("mpInput is NULL");
      mOutput.Init(*mpInput);
    }
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Y.CopyFrom(X);}

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Y.SetZero(); }

  };
  
  /**
   * \brief Distribute the input to several output
   * 
   * \ingroup CuNNMisc
   * 
   */
  class CuDistrib : public CuComponent
  {
    public:
    CuDistrib(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred),size(0),ErrorInputVec()
    {
    }

    ~CuDistrib()
    { }

    ComponentType GetType() const
    { return DISTRIB; }

    const char* GetName() const
    { return "<distrib>"; }
   
    void ReadFromStream(std::istream& rIn)
    {
      rIn >> std::ws >> size;
      ErrorInputVec.clear();
      for (int i=0; i<size;++i)
        ErrorInputVec.push_back(NULL);
    }

    void WriteToStream(std::ostream& rOut)  
    {
      rOut<<size<<std::endl;
    }
    
    void Propagate()
    {
      if (NULL == mpInput) Error("mpInput is NULL");
      mOutput.Init(*mpInput);
    }
    
    int GetOutSect() 
    {
      return size;
    }
    
    const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)
    {
      if (pos>=0 && pos<size)
        return *ErrorInputVec[pos];
      return *ErrorInputVec[0];
    }

    void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)
    {
      if (pos==0)
        mpErrorInput=&rErrorInput;
      if (pos>=0 && pos<size)
        ErrorInputVec[pos]=&rErrorInput;
    }  
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Y.CopyFrom(X);}

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    {
      Y.SetZero();
      for (int i=0;i<size;++i)
        Y.AddScaled(1.0,*ErrorInputVec[i],1.0);
    }
    
    int size;
    ConstMatrixPtrVec ErrorInputVec;
    Vector<BaseFloat> Scale;
  };
  
  /**
   * \brief Combining(Adding) several inputs together
   * 
   * \ingroup CuNNMisc
   * 
   */
  class CuCombine : public CuComponent
  {
    public:
    CuCombine(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred),size(0),InputVec()
    {
    }

    ~CuCombine()
    { }

    ComponentType GetType() const
    { return COMBINE; }

    const char* GetName() const
    { return "<combine>"; }
   
    void ReadFromStream(std::istream& rIn)
    {
      rIn >> std::ws >> size;
      InputVec.clear();
      for (int i=0; i<size;++i)
        InputVec.push_back(NULL);
    }

    void WriteToStream(std::ostream& rOut)  
    {
      rOut<<size<<std::endl;
    }
    
    void Backpropagate()
    {
      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
      mErrorOutput.Init(*mpErrorInput);
    }
    
    int GetInSect()
    {
      return size;
    }
    
    const CuMatrix<BaseFloat>& GetInput(int pos=0)
    {
      if (pos>=0 && pos<size)
        return *InputVec[pos];
      return *InputVec[0];
    }

    /// Set input vector (bind with the preceding NetworkComponent)
    void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)
    {
      if (pos==0)
        mpInput=&rInput;
      if (pos>=0 && pos<size)
        InputVec[pos]=&rInput;
    }
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    {
      Y.SetZero();
      for (int i=0;i<size;++i)
        Y.AddScaled(1.0,*InputVec[i],1.0);
    }

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    {
      Y.CopyFrom(X);
    }
    
    int size;
    ConstMatrixPtrVec InputVec;
  };
  
  /**
   * \brief Divide the input matrix to several outputs
   * 
   * \ingroup CuNNMisc
   * 
   */
  class CuDivide : public CuComponent
  {
    public:
    CuDivide(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred),size(0)
    { }

    ~CuDivide()
    { }

    ComponentType GetType() const
    { return DIVIDE; }

    const char* GetName() const
    { return "<divide>"; }
    
    int GetOutSect() 
    {
      return size;
    }
   
    void ReadFromStream(std::istream& rIn)
    {
      int len;
      for (int i=0; i<size;++i)
        delete OutputVec[i];
      rIn >> std::ws >> size;
      OutputVec.clear();
      ErrorInputVec.clear();
      for (int i=0; i<size;++i)
      {
        rIn>>len;
        OutputVec.push_back(new CuMatrix<BaseFloat>());
        ErrorInputVec.push_back(NULL);
        SectLen.push_back(len);
      }
    }

    void WriteToStream(std::ostream& rOut)  
    {
      rOut<<size<<" ";
      for (int i=0; i<size;++i)
        rOut<<SectLen[i]<<" ";
      rOut<<std::endl;
    }
    
    const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)
    {
      if (pos>=0 && pos<size)
        return *ErrorInputVec[pos];
      return *ErrorInputVec[0];
    }

    void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)
    {
      if (pos==0)
        mpErrorInput=&rErrorInput;
      if (pos>=0 && pos<size)
        ErrorInputVec[pos]=&rErrorInput;
    }

    const CuMatrix<BaseFloat>& GetOutput(int pos=0)
    {
      if (pos>=0 && pos<size)
          return *OutputVec[pos];
      return *OutputVec[0];
    }

    void Propagate()
    {
      if (NULL == mpInput) Error("mpInput is NULL");
      int loc=0;
      for (int i=0;i<size;++i)
      {
        OutputVec[i]->Init(*mpInput,loc,SectLen[i]);
        loc+=SectLen[i];
      }
    }

    void Backpropagate()
    {
      int loc=0;
      mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs());
      for (int i=0;i<size;++i)
      {
        mErrorOutput.CopyCols(SectLen[i], 0, GetErrorInput(i), loc);
        loc += SectLen[i];
      }
    }
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Error("__func__ Nonsense"); }

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Error("__func__ Nonsense"); }
    
    int size;
    MatrixPtrVec OutputVec;
    ConstMatrixPtrVec ErrorInputVec;
    std::vector<int> SectLen;

  };
  
  /**
   * \brief Merge several input matrices to one single output
   * 
   * \ingroup CuNNMisc
   * 
   */
  class CuMerge : public CuComponent
  {
    public:
    CuMerge(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred)
    { }

    ~CuMerge()
    { }

    ComponentType GetType() const
    { return MERGE; }

    const char* GetName() const
    { return "<merge>"; }
    
    int GetInSect()
    {
      return size;
    }
   
    void ReadFromStream(std::istream& rIn)
    {
      int len;
      for (int i=0; i<size;++i)
        delete ErrorOutputVec[i];
      rIn >> std::ws >> size;
      ErrorOutputVec.clear();
      InputVec.clear();
      for (int i=0; i<size;++i)
      {
        rIn>>len;
        ErrorOutputVec.push_back(new CuMatrix<BaseFloat>());
        InputVec.push_back(NULL);
        SectLen.push_back(len);
      }
    }

    void WriteToStream(std::ostream& rOut)  
    {
      rOut<<size<<" ";
      for (int i=0; i<size;++i)
        rOut<<SectLen[i]<<" ";
      rOut<<std::endl;
    }
    
    const CuMatrix<BaseFloat>& GetInput(int pos=0)
    {
      if (pos>=0 && pos<size)
        return *InputVec[pos];
      return *InputVec[0];
    }

    /// Set input vector (bind with the preceding NetworkComponent)
    void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)
    {
      if (pos==0)
        mpInput=&rInput;
      if (pos>=0 && pos<size)
        InputVec[pos]=&rInput;
    }

    const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0)
    {
        if (pos>=0 && pos<size)
            return *ErrorOutputVec[pos];
        return *ErrorOutputVec[0];
    }

    void Propagate()
    {
      int loc=0;
      mOutput.Init(GetInput().Rows(),GetNOutputs());
      for (int i=0;i<size;++i)
      {
        mOutput.CopyCols(SectLen[i], 0, GetInput(i), loc);
        loc += SectLen[i];
      }
    }
    
    void Backpropagate()
    {
      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
      int loc=0;
      for (int i=0;i<size;++i)
      {
        ErrorOutputVec[i]->Init(*mpErrorInput,loc,SectLen[i]);
        loc+=SectLen[i];
      }
    }
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Error("__func__ Nonsense"); }

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Error("__func__ Nonsense"); }
    
    int size;
    
    ConstMatrixPtrVec InputVec;
    MatrixPtrVec ErrorOutputVec;
    std::vector<int> SectLen;

  };
  
  /**
   * \brief Reordering several inputs
   * 
   * \ingroup CuNNMisc
   * 
   */
  class CuReorder : public CuComponent
  {
    public:
    CuReorder(size_t nInputs, size_t nOutputs, CuComponent* pPred)
      : CuComponent(nInputs,nOutputs,pPred)
    { }

    ~CuReorder()
    { }

    ComponentType GetType() const
    { return REORDER; }

    const char* GetName() const
    { return "<reorder>"; }
    
    int GetInSect() 
    {
      return size;
    }
      
    int GetOutSect()
    {
      return size;
    }
   
    void ReadFromStream(std::istream& rIn)
    {
      int pos;
      for (int i=0; i<size;++i)
        delete PipeVec[i];
      rIn >> std::ws >> size;
      Order.clear();
      PipeVec.clear();
      for (int i=0; i<size;++i)
      {
        rIn>>pos;
        Order.push_back(pos);
        PipeVec.push_back(new CuPipe(0,0,NULL));
      }
    }

    void WriteToStream(std::ostream& rOut)  
    {
      rOut << size<< " ";
      for (int i=0; i<size;++i)
        rOut<<Order[i]<<" ";
      rOut<<std::endl;
    }
    
    void Propagate()
    {
      if (NULL == mpInput) Error("mpInput is NULL");
      for (int i=0; i<size;++i)
        PipeVec[i]->Propagate();
    }
    
    void Backpropagate()
    {
      if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
      for (int i=0; i<size;++i)
        PipeVec[i]->Backpropagate();
    }
    
    /// IO Data getters
    const CuMatrix<BaseFloat>& GetInput(int pos=0)
    {
      return PipeVec[pos]->GetInput();
    }
    const CuMatrix<BaseFloat>& GetOutput(int pos=0)
    {
      return PipeVec[Order[pos]]->GetOutput();
    }
    const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)
    { 
      return PipeVec[Order[pos]]->GetErrorInput();
    }
    const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0)
    {
      return PipeVec[pos]->GetErrorOutput();
    }

    /// Set input vector (bind with the preceding NetworkComponent)
    void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)
    {
      if (pos == 0)
        mpInput = &rInput;
      PipeVec[pos]->SetInput(rInput);
    }          
    /// Set error input vector (bind with the following NetworkComponent) 
    void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)
    {
      if (pos == 0)
        mpErrorInput = &rErrorInput;
      PipeVec[Order[pos]]->SetErrorInput(rErrorInput);
    }
     
   protected:
    
    void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Error("__func__ Nonsense"); }

    void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
    { Error("__func__ Nonsense"); }
    
    int size;
    
    std::vector<int> Order;
    
    std::vector< CuPipe* > PipeVec;
  };
  
} //namespace



#endif