#ifndef _CU_RBM_SPARSE_H_
#define _CU_RBM_SPARSE_H_


#include "cuComponent.h"
#include "cumatrix.h"
#include "cuRbm.h"


#include "Matrix.h"
#include "Vector.h"


namespace TNet {

  class CuRbmSparse : public CuRbmBase
  {
    public:

      CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred); 
      ~CuRbmSparse();  
      
      ComponentType GetType() const;
      const char* GetName() const;

      //CuUpdatableComponent API
      void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);

      void Update();

      //RBM training API
      void Propagate(const CuMatrix<BaseFloat>& visProbs, CuMatrix<BaseFloat>& hidProbs);
      void Reconstruct(const CuMatrix<BaseFloat>& hidState, CuMatrix<BaseFloat>& visProbs);
      void RbmUpdate(const CuMatrix<BaseFloat>& pos_vis, const CuMatrix<BaseFloat>& pos_hid, const CuMatrix<BaseFloat>& neg_vis, const CuMatrix<BaseFloat>& neg_hid);

      RbmUnitType VisType()
      { return mVisType; }

      RbmUnitType HidType()
      { return mHidType; }

      //static void BinarizeProbs(const CuMatrix<BaseFloat>& probs, CuMatrix<BaseFloat>& states);

      //I/O
      void ReadFromStream(std::istream& rIn);
      void WriteToStream(std::ostream& rOut);

    protected:
      CuMatrix<BaseFloat> mVisHid;  ///< Matrix with neuron weights
      CuVector<BaseFloat> mVisBias;       ///< Vector with biases
      CuVector<BaseFloat> mHidBias;       ///< Vector with biases

      CuMatrix<BaseFloat> mVisHidCorrection; ///< Matrix for linearity updates
      CuVector<BaseFloat> mVisBiasCorrection;      ///< Vector for bias updates
      CuVector<BaseFloat> mHidBiasCorrection;      ///< Vector for bias updates

      CuMatrix<BaseFloat> mBackpropErrBuf;

      RbmUnitType mVisType;
      RbmUnitType mHidType;

      ////// sparsity 
      BaseFloat mSparsityPrior; ///< sparsity target (unit activity prior)
      BaseFloat mLambda; ///< exponential decay factor for q (observed probability of unit to be active)
      BaseFloat mSparsityCost; ///< sparsity cost coef.

      CuVector<BaseFloat> mSparsityQ;
      CuVector<BaseFloat> mSparsityQCurrent;
      CuVector<BaseFloat> mVisMean; ///< buffer for mean visible

  };




  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuRbmSparse::
  inline 
  CuRbmSparse::
  CuRbmSparse(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuRbmBase(nInputs, nOutputs, pPred), 
      mVisHid(nInputs,nOutputs), 
      mVisBias(nInputs), mHidBias(nOutputs),
      mVisHidCorrection(nInputs,nOutputs), 
      mVisBiasCorrection(nInputs), mHidBiasCorrection(nOutputs),
      mBackpropErrBuf(),
      mVisType(BERNOULLI),
      mHidType(BERNOULLI),

      mSparsityPrior(0.0001),
      mLambda(0.95),
      mSparsityCost(1e-7),
      mSparsityQ(nOutputs),
      mSparsityQCurrent(nOutputs),
      mVisMean(nInputs)
  { 
    mVisHidCorrection.SetConst(0.0);
    mVisBiasCorrection.SetConst(0.0);
    mHidBiasCorrection.SetConst(0.0);

    mSparsityQ.SetConst(mSparsityPrior);
    mSparsityQCurrent.SetConst(0.0);
    mVisMean.SetConst(0.0);
  }


  inline
  CuRbmSparse::
  ~CuRbmSparse()
  { }

  inline CuComponent::ComponentType
  CuRbmSparse::
  GetType() const
  {
    return CuComponent::RBM_SPARSE;
  }

  inline const char*
  CuRbmSparse::
  GetName() const
  {
    return "<rbmsparse>";
  }



} //namespace



#endif