summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuSparseLinearity.h
blob: 3cdf07879e724771d985e8b07b143021c5572578 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#ifndef _CUSPARSE_LINEARITY_H_
#define _CUSPARSE_LINEARITY_H_


#include "cuComponent.h"
#include "cumatrix.h"


#include "Matrix.h"
#include "Vector.h"


namespace TNet {

  /**
   * \brief CuSparseLinearity summation function
   *
   * \ingroup CuNNUpdatable
   * Using weight masks to avoid fluctuation in the output
   *  -Weights are masked when it is lower than certain threshold - mSparsifyWeightThreshold
   *  -Weights are activated when the accumulated change is larger than certan value - mUnsparsifyAccu
   *  -L1 lasso function zeroing weights
   *  .
   * \sa CuBiasedLinearity
   */
  class CuSparseLinearity : public CuUpdatableComponent
  {
    public:

      CuSparseLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred); 
      ~CuSparseLinearity();  
      
      ComponentType GetType() const;
      const char* GetName() const;

      void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
      void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);

      void Update();
      void UpdateMask();

      void ReadFromStream(std::istream& rIn);
      void WriteToStream(std::ostream& rOut);

      void L1(BaseFloat l1) {
        mL1Const = l1;
      }

    protected:
      CuMatrix<BaseFloat> mLinearity;  ///< Matrix with neuron weights
      CuVector<BaseFloat> mBias;       ///< Vector with biases
      CuMatrix<BaseFloat> mSparsityMask; ///< Mask which selects active weights

      CuMatrix<BaseFloat> mLinearityCorrection; ///< Matrix for linearity updates
      CuVector<BaseFloat> mBiasCorrection;      ///< Vector for bias updates

      CuMatrix<BaseFloat> mLinearityCorrectionAccu; ///< Accumulator for linearity updates

      BaseFloat mL1Const; ///< L1 regularization constant

      size_t mNFrames; ///< Number of accumulated frames 
      BaseFloat mSparsifyWeightThreshold; ///< Cutoff
      BaseFloat mUnsparsifyAccu; ///< Threshold to unsparsify the Cutoff

      
  };




  ////////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuSparseLinearity::
  /**
   * Constructs a sparse linear layer with nInputs inputs and nOutputs
   * outputs, chained after the component pPred.
   *
   * The update buffers are zeroed here. All scalar hyper-parameters are given
   * defined values, including mL1Const, which previously was left
   * indeterminate (reading it before L1() was called was undefined behaviour).
   * mLinearity, mBias and mSparsityMask are only sized here, not filled.
   * NOTE(review): presumably ReadFromStream() fills them — confirm.
   */
  inline
  CuSparseLinearity::
  CuSparseLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred)
    : CuUpdatableComponent(nInputs, nOutputs, pPred),
      mLinearity(nInputs,nOutputs), mBias(nOutputs), mSparsityMask(nInputs,nOutputs),
      mLinearityCorrection(nInputs,nOutputs), mBiasCorrection(nOutputs),
      mLinearityCorrectionAccu(nInputs,nOutputs),
      mL1Const(0.0f),  // no L1 penalty until L1() sets one
      mNFrames(0), mSparsifyWeightThreshold(1.0e-3),
      mUnsparsifyAccu(1e20f)  // very large limit: re-activation is effectively off by default
  {
    // Gradient buffers and the accumulator start from zero.
    mLinearityCorrection.SetConst(0.0f);
    mBiasCorrection.SetConst(0.0f);
    mLinearityCorrectionAccu.SetConst(0.0f);
  }


  /// Destructor: nothing to release explicitly; members clean up themselves.
  inline
  CuSparseLinearity::~CuSparseLinearity()
  {
  }

  /// Returns the component type tag SPARSE_LINEARITY.
  inline CuComponent::ComponentType
  CuSparseLinearity::GetType() const
  {
    const CuComponent::ComponentType kType = CuComponent::SPARSE_LINEARITY;
    return kType;
  }

  /// Returns the textual token naming this component.
  inline const char*
  CuSparseLinearity::GetName() const
  {
    const char* const kName = "<sparselinearity>";
    return kName;
  }



} //namespace



#endif