#ifndef _CUNETWORK_COMPONENT_I_H
#define _CUNETWORK_COMPONENT_I_H


#include "Vector.h"
#include "Matrix.h"
#include "Error.h"

#include "cumatrix.h"

#include <iostream>
#include <stdexcept>
#include <vector>

/// \defgroup CuNNLayer CuNN Layer types
/// \ingroup CuNNComp

/// \defgroup CuNNUpdatable CuNN Updatable Layer
/// \ingroup CuNNLayer

/// \defgroup CuNNActivation CuNN Activation Func Layer
/// \ingroup CuNNLayer

/// \defgroup CuNNMisc CuNN Misc Layer
/// \ingroup CuNNLayer

namespace TNet {

    
  /**
   * \brief Neural network building blocks
   *
   * Basic element of the network,
   * it is a box with defined inputs and outputs, 
   * and functions to refresh outputs
   *
   * It is able to compute the transformation function (forward pass)
   * and the Jacobian function (backward pass),
   * which are to be implemented in descendants.
   *
   * Components together form a doubly-linked list.
   *
   * Every input row is one frame of data,
   * and the input size is the number of columns.
   */ 
  class CuComponent 
  {
    public:
    /// Types of the net components
    typedef enum { 
      UPDATABLE_COMPONENT = 0x0100,
      BIASED_LINEARITY,
      DISCRETE_LINEARITY,
      SHARED_LINEARITY,
      SPARSE_LINEARITY,
      RBM,
      RBM_SPARSE,
      RECURRENT,
      LINEARITY,
      UPDATABLEBIAS,
      DISCRETE,
      COMPOUND,

      ACT_FUN = 0x0200,
      SOFTMAX, 
      SIGMOID,

      OTHER = 0x0400,
      EXPAND,
      COPY,
      TRANSPOSE,
      BLOCK_LINEARITY,
      WINDOW,
      BIAS,
      LOG,
      PIPE,
      LEARNSTOP,
      DISTRIB,
      COMBINE,
      DIVIDE,
      MERGE,
      REORDER,

      BLOCK_ARRAY,
    } ComponentType;
    
    typedef std::vector< CuMatrix<BaseFloat>* > MatrixPtrVec;
    typedef std::vector< const CuMatrix<BaseFloat>* > ConstMatrixPtrVec;

    //////////////////////////////////////////////////////////////
    // Constructor & Destructor
    public: 
      CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred); 
      virtual ~CuComponent();  
       
    //////////////////////////////////////////////////////////////
    // Interface specification (public)
    public:
      /// Get Type Identification of the component
      virtual ComponentType GetType() const = 0;  
      /// Get Type Label of the component
      virtual const char* GetName() const = 0;
      /// Return if the component is UpdatableComponent
      virtual bool IsUpdatable() const 
      { return false; }

      /// Get size of input vectors
      size_t GetNInputs() const;  
      /// Get size of output vectors 
      size_t GetNOutputs() const;

      /// Set size of input vectors
      size_t SetNInputs(size_t nInputs);
      /// Set size of output vectors
      size_t SetNOutputs(size_t nOutputs);
      /// Set the previous component
      void SetPrevious(CuComponent* pPred);
      /// Set the next component
      void SetNext(CuComponent* pNxt);
      
      /// Return the number of different inputs for complex component
      virtual int GetInSect(){return 1;}
      /// Return the number of different outputs for complex component
      virtual int GetOutSect(){return 1;}

      /// IO Data getters
      virtual const CuMatrix<BaseFloat>& GetInput(int pos=0)  {
    if (NULL == mpInput) Error("mpInput is NULL");
    return *mpInput;
  }
      virtual const CuMatrix<BaseFloat>& GetOutput(int pos=0)  {
    return mOutput;
  }
      virtual const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)  {
    if (NULL == mpErrorInput) Error("mpErrorInput is NULL");
    return *mpErrorInput;
  }
      virtual const CuMatrix<BaseFloat>& GetErrorOutput(int pos=0)  {
    return mErrorOutput;
  }

      /// Set input vector (bind with the preceding NetworkComponent)
      virtual void SetInput(const CuMatrix<BaseFloat>& rInput,int pos=0)  {
    mpInput = &rInput;
  }
      /// Set error input vector (bind with the following NetworkComponent) 
      virtual void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0)  {
    mpErrorInput = &rErrorInput;
  }  

      /// Perform forward pass propagateion Input->Output,
      /// wrapper for the PropagateFnc method
      virtual void Propagate()  {
    //initialize output buffer
    mOutput.Init(GetInput().Rows(),GetNOutputs());
    //do the dimensionality test
    if(GetNInputs() != GetInput().Cols()) {
      KALDI_ERR << "Non-matching INPUT dim!!! Network dim: " << GetNInputs() 
                << " Data dim: " << GetInput().Cols();
    }
    //run transform
    PropagateF(GetInput(),mOutput);
  }
      /// Perform backward pass propagateion ErrorInput->ErrorOutput,
      /// wrapper for the BackpropagateFnc method
      virtual void Backpropagate()  {
    //re-initialize the output buffer
    mErrorOutput.Init(GetErrorInput().Rows(),GetNInputs());

    //do the dimensionality test
    assert(GetErrorInput().Cols() == mNOutputs);
    assert(mErrorOutput.Cols() == mNInputs);
    assert(mErrorOutput.Rows() == GetErrorInput().Rows());

    //transform
    BackpropagateF(GetErrorInput(),mErrorOutput);
  }
 
      /// Reads the component parameters from stream
      virtual void ReadFromStream(std::istream& rIn)  { }
      /// Writes the components parameters to stream
      virtual void WriteToStream(std::ostream& rOut)  { } 
      
      /// Public wrapper for PropagateFnc
      virtual void PropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
	  PropagateFnc(X,Y);
  }
  
      /// Public wrapper for BackpropagateFnc
      virtual void BackpropagateF(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y)
  {
	  BackpropagateFnc(X,Y);
  }

    ///////////////////////////////////////////////////////////////
    // Nonpublic member functions used to update data outputs 
    protected:
      /// Forward pass transformation (to be implemented by descendents...)
      /// \param[in] X InputMatrix (Network input or Output from last layer)
      /// \param[out] Y OutputMatrix (Network output or input of the next layer)
      virtual void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) = 0;
      /// Backward pass transformation (to be implemented by descendents...)
      /// \param[in] X InputMatrix (Network Error, objective func output, or Error output from the next layer)
      /// \param[out] Y OutputMatrix (Error input of the last layer)
      virtual void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y) = 0;
      
   
    ///////////////////////////////////////////////////////////////
    // data members
    protected:

      size_t mNInputs;  ///< Size of input vectors
      size_t mNOutputs; ///< Size of output vectors 
      
      const CuMatrix<BaseFloat>* mpInput; ///< inputs are NOT OWNED by component
      const CuMatrix<BaseFloat>* mpErrorInput;///< inputs are NOT OWNED by component

      CuMatrix<BaseFloat> mOutput; ///< outputs are OWNED by component
      CuMatrix<BaseFloat> mErrorOutput; ///< outputs are OWNED by component
      
      CuComponent* preComp;///< The preceding component in the Network
      CuComponent* nxtComp;///< The following component in the Network
  };


  /**
   * \brief Class UpdatableComponent is a box which has some 
   * parameters adjustable by learning
   * 
   * you can set the learning rate, lock the params,
   * and learn from each data observation
   */
  class CuUpdatableComponent : public CuComponent
  {
    public:
      //////////////////////////////////////////////////////////////
      // Construction / destruction

      /// Takes the same arguments as the CuComponent constructor
      CuUpdatableComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred);
      virtual ~CuUpdatableComponent();

      //////////////////////////////////////////////////////////////
      // Public interface

      /// Updatable components always answer true
      virtual bool IsUpdatable() const
      {
        return true;
      }

      /// Compute the gradient and update the parameters in one step;
      /// implemented by the concrete component
      virtual void Update() = 0;

      /// Set the learning rate of gradient descent
      void LearnRate(BaseFloat rate);
      /// Get the learning rate of gradient descent
      BaseFloat LearnRate();

      /// Set the momentum
      void Momentum(BaseFloat mmt);
      /// Get the momentum
      BaseFloat Momentum();

      /// Set the weight decay rate that penalizes large weights
      void Weightcost(BaseFloat cost);
      /// Get the weight decay rate
      BaseFloat Weightcost();

      /// Set whether the gradient is divided by the number of frames
      void GradDivFrm(bool div);
      /// Get whether the gradient is divided by the number of frames
      bool GradDivFrm();

    protected:
      BaseFloat mLearningRate; ///< learning rate
      BaseFloat mMomentum;     ///< momentum
      BaseFloat mWeightcost;   ///< weight decay rate
      bool      mGradDivFrm;   ///< divide the gradient by frames?

  };

  //////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuComponent::
  /// Construct a component and, when a predecessor is given, link the two.
  ///
  /// \param nInputs  width of the input vectors
  /// \param nOutputs width of the output vectors
  /// \param pPred    preceding component, or NULL when there is none
  ///
  /// Every member is initialized here, including nxtComp: it starts out as
  /// NULL and is only filled in when a later component registers itself
  /// through SetNext(). (It used to be left uninitialized.)
  ///
  /// NOTE: the linking goes through virtual calls while this object is
  /// still under construction, so calls made on this object resolve to the
  /// CuComponent versions rather than to overrides of a derived class.
  inline
  CuComponent::
  CuComponent(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : mNInputs(nInputs), mNOutputs(nOutputs),
      mpInput(NULL), mpErrorInput(NULL),
      mOutput(), mErrorOutput(),
      preComp(pPred), nxtComp(NULL)
  {
    /* DOUBLE LINK the Components */
    if (pPred != NULL) {
      // our input becomes the predecessor's output ...
      SetPrevious(pPred);
      // ... and the predecessor's error input becomes our error output
      pPred->SetNext(this);
    }
  }

  /// Remember \p pPred as the preceding component and, unless it is NULL,
  /// bind this component's input to the predecessor's output.
  /// The predecessor's own "next" link is not touched here.
  inline void
  CuComponent::SetPrevious(CuComponent* pPred)
  {
    preComp = pPred;

    // nothing to bind without a predecessor
    if (NULL == pPred) {
      return;
    }

    SetInput(pPred->GetOutput());
  }
  
  /// Remember \p pNxt as the following component and, unless it is NULL,
  /// bind this component's error input to the successor's error output.
  /// The successor's own "previous" link is not touched here.
  inline void
  CuComponent::SetNext(CuComponent* pNxt)
  {
    nxtComp = pNxt;

    // nothing to bind without a successor
    if (NULL == pNxt) {
      return;
    }

    SetErrorInput(pNxt->GetErrorOutput());
  }

  /// Destructor with an empty body: no explicit cleanup is performed.
  inline
  CuComponent::~CuComponent()
  {
  }
  
  inline size_t
  CuComponent::
  GetNInputs() const
  {
    return mNInputs;
  }

  inline size_t
  CuComponent::
  GetNOutputs() const
  {
    return mNOutputs;
  }

  /// Store a new width for the input vectors.
  ///
  /// \param nInputs the new input width
  /// \return the width that is now stored
  ///
  /// The function is declared to return size_t, so it has to return a
  /// value: flowing off the end of a value-returning function is undefined
  /// behavior.
  inline size_t
  CuComponent::
  SetNInputs(size_t nInputs)
  {
    mNInputs = nInputs;
    return mNInputs;
  }
  
  /// Store a new width for the output vectors.
  ///
  /// \param nOutputs the new output width
  /// \return the width that is now stored
  ///
  /// The function is declared to return size_t, so it has to return a
  /// value: flowing off the end of a value-returning function is undefined
  /// behavior.
  inline size_t
  CuComponent::
  SetNOutputs(size_t nOutputs)
  {
    mNOutputs = nOutputs;
    return mNOutputs;
  }
  
  


  //////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // UpdatableComponent::
  
  /// Forward the sizes and the predecessor to CuComponent and start with
  /// learning rate 0.0, momentum 0, weight cost 0 and gradient division
  /// by frames switched on.
  inline
  CuUpdatableComponent::CuUpdatableComponent(size_t nInputs,
                                             size_t nOutputs,
                                             CuComponent *pPred)
    : CuComponent(nInputs, nOutputs, pPred),
      mLearningRate(0.0),
      mMomentum(0),
      mWeightcost(0),
      mGradDivFrm(true)
  {
  }


  /// Destructor with an empty body: no explicit cleanup is performed.
  inline
  CuUpdatableComponent::~CuUpdatableComponent()
  {
  }


  /// Store \p rate as the learning rate.
  inline void
  CuUpdatableComponent::LearnRate(BaseFloat rate)
  {
    this->mLearningRate = rate;
  }


  /// \return the stored learning rate
  inline BaseFloat
  CuUpdatableComponent::LearnRate()
  {
    return this->mLearningRate;
  }
  

  /// Store \p mmt as the momentum.
  inline void
  CuUpdatableComponent::Momentum(BaseFloat mmt)
  {
    this->mMomentum = mmt;
  }


  /// \return the stored momentum
  inline BaseFloat
  CuUpdatableComponent::Momentum()
  {
    return this->mMomentum;
  }
  
  
  /// Store \p cost as the weight cost (weight decay rate).
  inline void
  CuUpdatableComponent::Weightcost(BaseFloat cost)
  {
    this->mWeightcost = cost;
  }


  /// \return the stored weight cost (weight decay rate)
  inline BaseFloat
  CuUpdatableComponent::Weightcost()
  {
    return this->mWeightcost;
  }

  
  inline void 
  CuUpdatableComponent::
  GradDivFrm(bool div)
  {
    mGradDivFrm = div;
  }
   
  inline bool
  CuUpdatableComponent::
  GradDivFrm()
  {
    return mGradDivFrm;
  }

} // namespace TNet


#endif