#ifndef _CULINEARITY_H_
#define _CULINEARITY_H_


#include "cuComponent.h"
#include "cumatrix.h"


#include "Matrix.h"
#include "Vector.h"


namespace TNet {
  /**
   * \brief CuLinearity: linear transform layer (multiplication by a trainable weight matrix)
   *
   * \ingroup CuNNUpdatable
   * Implements forward pass: \f[ Y_j = \sum_{i=0}^{N-1} w_{ij} X_i + \beta_j \f]
   * Error propagation: \f[ E_i = \sum_{j=0}^{M-1} w_{ij} e_j \f]
   * where N is the number of inputs and M is the number of outputs.
   *
   * NOTE(review): this class declares no bias member of its own (only
   * mLinearity and mLinearityCorrection), so the \f$ \beta_j \f$ term above may
   * have been copied from a biased-linearity variant -- confirm against the
   * PropagateFnc() implementation.
   *
   * Weight adjustment: \f[ W_{ij} = (1-D)(w_{ij} - \alpha(1-\mu)x_i e_j - \mu \Delta) \f]
   * where
   *  - D for weight decay => penalizing large weights
   *  - \f$ \alpha \f$ for learning rate
   *  - \f$ \mu \f$ for momentum => avoiding oscillation
   *  - \f$ \Delta \f$ for the previous weight update (presumably kept in
   *    mLinearityCorrection -- confirm in Update())
   */
  class CuLinearity : public CuUpdatableComponent
  {
    public:

      /// Creates the layer; weight and correction matrices are both
      /// constructed with dimensions (nInputs, nOutputs).
      CuLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred); 
      ~CuLinearity();  
      
      /// Returns CuComponent::LINEARITY.
      ComponentType GetType() const;
      /// Returns the tag string "<linearity>".
      const char* GetName() const;

      /// Forward-pass hook: computes output Y from input X
      /// (presumably invoked by the base-class propagation driver).
      void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
      /// Error back-propagation hook; presumably X is the error at the output
      /// and Y receives the error at the input -- confirm in the implementation.
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);

      /// Applies one weight update (presumably the adjustment formula above).
      void Update();

      /// Loads layer parameters from a stream (format defined in the implementation).
      void ReadFromStream(std::istream& rIn);
      /// Stores layer parameters to a stream (format defined in the implementation).
      void WriteToStream(std::ostream& rOut);

    protected:
      CuMatrix<BaseFloat> mLinearity;  ///< Matrix with neuron weights, (nInputs, nOutputs)

      CuMatrix<BaseFloat> mLinearityCorrection; ///< Matrix for linearity updates; same dimensions as mLinearity, zeroed at construction

  };




  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuLinearity::
  /// Constructs the layer: the base component records the (nInputs, nOutputs)
  /// geometry and predecessor, and both the weight matrix and the correction
  /// matrix are created with dimensions (nInputs, nOutputs). Only the
  /// correction matrix is explicitly cleared here; the weights are expected to
  /// be filled later (presumably via ReadFromStream -- confirm with callers).
  inline CuLinearity::CuLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuUpdatableComponent(nInputs, nOutputs, pPred)
    , mLinearity(nInputs, nOutputs)
    , mLinearityCorrection(nInputs, nOutputs)
  {
    // No update has been accumulated yet, so the correction starts at zero.
    mLinearityCorrection.SetConst(0.0);
  }


  /// Destructor. The body is empty; the CuMatrix members and the base class
  /// are torn down by their own destructors.
  inline CuLinearity::~CuLinearity() { }

  /// Identifies this component as a linearity layer.
  /// \return CuComponent::LINEARITY
  inline CuComponent::ComponentType CuLinearity::GetType() const { return CuComponent::LINEARITY; }

  /// Human-readable tag of this component type.
  /// \return the static string "<linearity>" (presumably the token used when
  ///         serializing networks -- confirm against the reader/writer)
  inline const char* CuLinearity::GetName() const { return "<linearity>"; }



} //namespace



#endif