summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuNetwork.h
blob: 86b5229e865ffb190aea05fdf1864e5ab7220a71 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
#ifndef _CUNETWORK_H_
#define _CUNETWORK_H_

#include "cuComponent.h"

#include "cuBiasedLinearity.h"
//#include "cuBlockLinearity.h"
//#include "cuBias.h"
//#include "cuWindow.h"

#include "cuActivation.h"

#include "cuCRBEDctFeat.h"

#include "Vector.h"

#include <vector>

/**
 * \file cuNetwork.h
 * \brief CuNN manipulation class
 */

/// \defgroup CuNNComp CuNN Components

namespace TNet {
  /**
   * \brief Neural Network Manipulator & public interfaces
   *
   * \ingroup CuNNComp
   */
  class CuNetwork
  {
    //////////////////////////////////////
    // Typedefs
    typedef std::vector<CuComponent*> LayeredType; ///< Ordered list of layers; the network owns the pointers (deleted in dtor)
      
      //////////////////////////////////////
      // Disable copy construction and assignment
      // (the network owns raw CuComponent pointers; copying would double-delete)
    private:
      CuNetwork(CuNetwork&); 
      CuNetwork& operator=(CuNetwork&);
       
    public:
      CuNetwork() { }            ///< Construct an empty network
      CuNetwork(std::istream& rIn); ///< Construct by reading the network from a stream
      ~CuNetwork();              ///< Deletes all owned components

      /// Append a layer; links it to the current last layer and checks dims
      void AddLayer(CuComponent* layer);

      /// Number of components in the network
      int Layers()
      { return mNetComponents.size(); }

      /// Access the i-th component (no bounds check)
      CuComponent& Layer(int i)
      { return *mNetComponents[i]; }

      /// forward the data to the output
      void Propagate(CuMatrix<BaseFloat>& in, CuMatrix<BaseFloat>& out);

      /// backpropagate the error while updating weights
      void Backpropagate(CuMatrix<BaseFloat>& globerr); 

      void ReadNetwork(const char* pSrc);     ///< read the network from file
      void WriteNetwork(const char* pDst);    ///< write network to file

      void ReadNetwork(std::istream& rIn);    ///< read the network from stream
      void WriteNetwork(std::ostream& rOut);  ///< write network to stream

      size_t GetNInputs() const; ///< Dimensionality of the input features
      size_t GetNOutputs() const; ///< Dimensionality of the desired vectors

      /// set the learning rate
      void SetLearnRate(BaseFloat learnRate, const char* pLearnRateFactors = NULL); 
      BaseFloat GetLearnRate();  ///< get the learning rate value
      void PrintLearnRate();     ///< log the learning rate values

      void SetMomentum(BaseFloat momentum);       ///< set momentum for all updatable components
      void SetWeightcost(BaseFloat weightcost);   ///< set L2 weight-cost (weight decay)
      void SetL1(BaseFloat l1);                   ///< set L1 regularization constant

      void SetGradDivFrm(bool div); ///< toggle dividing the gradient by number of frames
      
      /// Reads a component from a stream
      static CuComponent* ComponentReader(std::istream& rIn, CuComponent* pPred);
      /// Dumps component into a stream
      static void ComponentDumper(std::ostream& rOut, CuComponent& rComp);

      
    private:
      /// Creates a component by reading from stream
      CuComponent* ComponentFactory(std::istream& In);


    private:
      LayeredType mNetComponents; ///< container with the network layers
      CuComponent* mpPropagErrorStopper; ///< Backpropagation stops after this component (nothing updatable precedes it); NULL disables
      BaseFloat mGlobLearnRate; ///< The global (unscaled) learn rate of the network
      const char* mpLearnRateFactors; ///< Per-component learn-rate scaling factors passed to SetLearnRate; NULL if unused
      

    //friend class NetworkGenerator; //<< For generating networks...

  };
    
  //////////////////////////////////////////////////////////////////////////
  // INLINE FUNCTIONS 
  // CuNetwork::
  inline 
  CuNetwork::
  CuNetwork(std::istream& rIn)
    : mpPropagErrorStopper(NULL), mGlobLearnRate(0.0), mpLearnRateFactors(NULL)
  {
    //parse and build the layers directly from the stream
    ReadNetwork(rIn);
  }


  inline
  CuNetwork::
  ~CuNetwork()
  {
    //the network owns its components: release every one of them
    for(size_t i=0; i<mNetComponents.size(); i++) {
      delete mNetComponents[i];
      mNetComponents[i] = NULL;
    }
    mNetComponents.clear();
  }

  
  inline void 
  CuNetwork::
  AddLayer(CuComponent* layer)
  {
    //when a layer already exists, verify the dimensions agree
    //and doubly-link the new layer behind the current last one
    if(!mNetComponents.empty()) {
      if(GetNOutputs() != layer->GetNInputs()) {
        Error("Nonmatching dims");
      }
      CuComponent* pLast = mNetComponents.back();
      layer->SetPrevious(pLast);
      pLast->SetNext(layer);
    }
    mNetComponents.push_back(layer);
  }


  inline void
  CuNetwork::
  Propagate(CuMatrix<BaseFloat>& in, CuMatrix<BaseFloat>& out)
  {
    //a network without layers acts as identity
    if(mNetComponents.empty()) {
      out.CopyFrom(in);
      return;
    }

    //the feature dimension must match the first layer
    if(in.Cols() != GetNInputs()) {
      std::ostringstream os;
      os << "Nonmatching dims"
         << " data dim is: " << in.Cols() 
         << " network needs: " << GetNInputs();
      Error(os.str());
    }

    //feed the input and run every layer front-to-back
    mNetComponents.front()->SetInput(in);
    for(LayeredType::iterator i = mNetComponents.begin();
        i != mNetComponents.end(); ++i) {
      (*i)->Propagate();
    }

    //the activation of the last layer is the network output
    out.CopyFrom(mNetComponents.back()->GetOutput());
  }




  /// Backpropagate the global error through the network, updating the
  /// weights of every updatable component whose learn rate is positive.
  /// \param globerr error matrix at the network output
  ///
  /// Fix: removed leftover debug prints (std::cout<<"tick1/2/3"<<std::flush)
  /// that wrote to stdout and forced a stream flush for every layer of
  /// every minibatch.
  inline void 
  CuNetwork::
  Backpropagate(CuMatrix<BaseFloat>& globerr) 
  {
    mNetComponents.back()->SetErrorInput(globerr);

    // back-propagation, output layer towards input layer
    LayeredType::reverse_iterator it;
    for(it=mNetComponents.rbegin(); it!=mNetComponents.rend(); ++it) {
      //stopper component does not propagate error (no updatable predecessors)
      if(*it != mpPropagErrorStopper) {
        //compute errors for preceding network components
        (*it)->Backpropagate();
      }
      //update weights if updatable component
      if((*it)->IsUpdatable()) {
        CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(**it); 
        if(rComp.LearnRate() > 0.0f) {
          rComp.Update();
        }
      }
      //stop backprop if no updatable components precede current component
      if(mpPropagErrorStopper == *it) break;
    }
  }

      
  /// Dimensionality of the input features.
  /// \return input dim of the first component, or 0 for an empty network
  ///
  /// Fix: the original guard `if(!mNetComponents.size() > 0)` parses as
  /// `(!size()) > 0` because `!` binds tighter than `>` — it behaved
  /// correctly only by accident; use empty() to state the intent.
  inline size_t
  CuNetwork::
  GetNInputs() const
  {
    if(mNetComponents.empty()) return 0;
    return mNetComponents.front()->GetNInputs();
  }


  /// Dimensionality of the network output (desired vectors).
  /// \return output dim of the last component, or 0 for an empty network
  ///
  /// Fix: the original guard `if(!mNetComponents.size() > 0)` parses as
  /// `(!size()) > 0` because `!` binds tighter than `>` — it behaved
  /// correctly only by accident; use empty() to state the intent.
  inline size_t
  CuNetwork::
  GetNOutputs() const
  {
    if(mNetComponents.empty()) return 0;
    return mNetComponents.back()->GetNOutputs();
  }

} //namespace

#endif