#ifndef _CUNETWORK_H_ #define _CUNETWORK_H_ #include "cuComponent.h" #include "cuBiasedLinearity.h" //#include "cuBlockLinearity.h" //#include "cuBias.h" //#include "cuWindow.h" #include "cuActivation.h" #include "cuCRBEDctFeat.h" #include "Vector.h" #include /** * \file cuNetwork.h * \brief CuNN manipulation class */ /// \defgroup CuNNComp CuNN Components namespace TNet { /** * \brief Nural Network Manipulator & public interfaces * * \ingroup CuNNComp */ class CuNetwork { ////////////////////////////////////// // Typedefs typedef std::vector LayeredType; ////////////////////////////////////// // Disable copy construction, assignment and default constructor private: CuNetwork(CuNetwork&); CuNetwork& operator=(CuNetwork&); public: CuNetwork() { } CuNetwork(std::istream& rIn); ~CuNetwork(); void AddLayer(CuComponent* layer); int Layers() { return mNetComponents.size(); } CuComponent& Layer(int i) { return *mNetComponents[i]; } /// forward the data to the output void Propagate(CuMatrix& in, CuMatrix& out); /// backpropagate the error while updating weights void Backpropagate(CuMatrix& globerr); void ReadNetwork(const char* pSrc); ///< read the network from file void WriteNetwork(const char* pDst); ///< write network to file void ReadNetwork(std::istream& rIn); ///< read the network from stream void WriteNetwork(std::ostream& rOut); ///< write network to stream size_t GetNInputs() const; ///< Dimensionality of the input features size_t GetNOutputs() const; ///< Dimensionality of the desired vectors /// set the learning rate void SetLearnRate(BaseFloat learnRate, const char* pLearnRateFactors = NULL); BaseFloat GetLearnRate(); ///< get the learning rate value void PrintLearnRate(); ///< log the learning rate values void SetMomentum(BaseFloat momentum); void SetWeightcost(BaseFloat weightcost); void SetL1(BaseFloat l1); void SetGradDivFrm(bool div); /// Reads a component from a stream static CuComponent* ComponentReader(std::istream& rIn, CuComponent* pPred); /// Dumps component into a stream static void ComponentDumper(std::ostream& rOut, CuComponent& rComp); private: /// Creates a component by reading from stream CuComponent* ComponentFactory(std::istream& In); private: LayeredType mNetComponents; ///< container with the network layers CuComponent* mpPropagErrorStopper; BaseFloat mGlobLearnRate; ///< The global (unscaled) learn rate of the network const char* mpLearnRateFactors; ///< The global (unscaled) learn rate of the network //friend class NetworkGenerator; //<< For generating networks... }; ////////////////////////////////////////////////////////////////////////// // INLINE FUNCTIONS // CuNetwork:: inline CuNetwork:: CuNetwork(std::istream& rSource) : mpPropagErrorStopper(NULL), mGlobLearnRate(0.0), mpLearnRateFactors(NULL) { ReadNetwork(rSource); } inline CuNetwork:: ~CuNetwork() { //delete all the components LayeredType::iterator it; for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { delete *it; *it = NULL; } mNetComponents.resize(0); } inline void CuNetwork:: AddLayer(CuComponent* layer) { if(mNetComponents.size() > 0) { if(GetNOutputs() != layer->GetNInputs()) { Error("Nonmatching dims"); } layer->SetPrevious(mNetComponents.back()); mNetComponents.back()->SetNext(layer); } mNetComponents.push_back(layer); } inline void CuNetwork:: Propagate(CuMatrix& in, CuMatrix& out) { //empty network => copy input if(mNetComponents.size() == 0) { out.CopyFrom(in); return; } //check dims if(in.Cols() != GetNInputs()) { std::ostringstream os; os << "Nonmatching dims" << " data dim is: " << in.Cols() << " network needs: " << GetNInputs(); Error(os.str()); } mNetComponents.front()->SetInput(in); //propagate LayeredType::iterator it; for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { (*it)->Propagate(); } //copy the output out.CopyFrom(mNetComponents.back()->GetOutput()); } inline void CuNetwork:: Backpropagate(CuMatrix& globerr) { mNetComponents.back()->SetErrorInput(globerr); // back-propagation LayeredType::reverse_iterator it; for(it=mNetComponents.rbegin(); it!=mNetComponents.rend(); ++it) { //stopper component does not propagate error (no updatable predecessors) if(*it != mpPropagErrorStopper) { //compute errors for preceding network components (*it)->Backpropagate(); } //update weights if updatable component if((*it)->IsUpdatable()) { CuUpdatableComponent& rComp = dynamic_cast(**it); if(rComp.LearnRate() > 0.0f) { rComp.Update(); } } //stop backprop if no updatable components precede current component if(mpPropagErrorStopper == *it) break; } } inline size_t CuNetwork:: GetNInputs() const { if(!mNetComponents.size() > 0) return 0; return mNetComponents.front()->GetNInputs(); } inline size_t CuNetwork:: GetNOutputs() const { if(!mNetComponents.size() > 0) return 0; return mNetComponents.back()->GetNOutputs(); } } //namespace #endif