diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuNetwork.h | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/CuTNetLib/cuNetwork.h')
-rw-r--r-- | src/CuTNetLib/cuNetwork.h | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuNetwork.h b/src/CuTNetLib/cuNetwork.h new file mode 100644 index 0000000..05e0ecb --- /dev/null +++ b/src/CuTNetLib/cuNetwork.h @@ -0,0 +1,227 @@ +#ifndef _CUNETWORK_H_ +#define _CUNETWORK_H_ + +#include "cuComponent.h" + +#include "cuBiasedLinearity.h" +//#include "cuBlockLinearity.h" +//#include "cuBias.h" +//#include "cuWindow.h" + +#include "cuActivation.h" + +#include "cuCRBEDctFeat.h" + +#include "Vector.h" + +#include <vector> + +/** + * \file cuNetwork.h + * \brief CuNN manipulation class + */ + +/// \defgroup CuNNComp CuNN Components + +namespace TNet { + /** + * \brief Nural Network Manipulator & public interfaces + * + * \ingroup CuNNComp + */ + class CuNetwork + { + ////////////////////////////////////// + // Typedefs + typedef std::vector<CuComponent*> LayeredType; + + ////////////////////////////////////// + // Disable copy construction, assignment and default constructor + private: + CuNetwork(CuNetwork&); + CuNetwork& operator=(CuNetwork&); + + public: + CuNetwork() { } + CuNetwork(std::istream& rIn); + ~CuNetwork(); + + void AddLayer(CuComponent* layer); + + int Layers() + { return mNetComponents.size(); } + + CuComponent& Layer(int i) + { return *mNetComponents[i]; } + + /// forward the data to the output + void Propagate(CuMatrix<BaseFloat>& in, CuMatrix<BaseFloat>& out); + + /// backpropagate the error while updating weights + void Backpropagate(CuMatrix<BaseFloat>& globerr); + + void ReadNetwork(const char* pSrc); ///< read the network from file + void WriteNetwork(const char* pDst); ///< write network to file + + void ReadNetwork(std::istream& rIn); ///< read the network from stream + void WriteNetwork(std::ostream& rOut); ///< write network to stream + + size_t GetNInputs() const; ///< Dimensionality of the input features + size_t GetNOutputs() const; ///< Dimensionality of the desired vectors + + /// set the learning rate + void SetLearnRate(BaseFloat learnRate, const char* pLearnRateFactors = NULL); + BaseFloat GetLearnRate(); ///< get the learning rate value + void PrintLearnRate(); ///< log the learning rate values + + void SetMomentum(BaseFloat momentum); + void SetWeightcost(BaseFloat weightcost); + void SetL1(BaseFloat l1); + + void SetGradDivFrm(bool div); + + /// Reads a component from a stream + static CuComponent* ComponentReader(std::istream& rIn, CuComponent* pPred); + /// Dumps component into a stream + static void ComponentDumper(std::ostream& rOut, CuComponent& rComp); + + + private: + /// Creates a component by reading from stream + CuComponent* ComponentFactory(std::istream& In); + + + private: + LayeredType mNetComponents; ///< container with the network layers + CuComponent* mpPropagErrorStopper; + BaseFloat mGlobLearnRate; ///< The global (unscaled) learn rate of the network + const char* mpLearnRateFactors; ///< The global (unscaled) learn rate of the network + + + //friend class NetworkGenerator; //<< For generating networks... + + }; + + ////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuNetwork:: + inline + CuNetwork:: + CuNetwork(std::istream& rSource) + : mpPropagErrorStopper(NULL), mGlobLearnRate(0.0), mpLearnRateFactors(NULL) + { + ReadNetwork(rSource); + } + + + inline + CuNetwork:: + ~CuNetwork() + { + //delete all the components + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + delete *it; + *it = NULL; + } + mNetComponents.resize(0); + } + + + inline void + CuNetwork:: + AddLayer(CuComponent* layer) + { + if(mNetComponents.size() > 0) { + if(GetNOutputs() != layer->GetNInputs()) { + Error("Nonmatching dims"); + } + layer->SetPrevious(mNetComponents.back()); + mNetComponents.back()->SetNext(layer); + } + mNetComponents.push_back(layer); + } + + + inline void + CuNetwork:: + Propagate(CuMatrix<BaseFloat>& in, CuMatrix<BaseFloat>& out) + { + //empty network => copy input + if(mNetComponents.size() == 0) { + out.CopyFrom(in); + return; + } + + //check dims + if(in.Cols() != GetNInputs()) { + std::ostringstream os; + os << "Nonmatching dims" + << " data dim is: " << in.Cols() + << " network needs: " << GetNInputs(); + Error(os.str()); + } + mNetComponents.front()->SetInput(in); + + //propagate + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + (*it)->Propagate(); + } + + //copy the output + out.CopyFrom(mNetComponents.back()->GetOutput()); + } + + + + + inline void + CuNetwork:: + Backpropagate(CuMatrix<BaseFloat>& globerr) + { + mNetComponents.back()->SetErrorInput(globerr); + + // back-propagation + LayeredType::reverse_iterator it; + for(it=mNetComponents.rbegin(); it!=mNetComponents.rend(); ++it) { + //stopper component does not propagate error (no updatable predecessors) + if(*it != mpPropagErrorStopper) { + //compute errors for preceding network components + (*it)->Backpropagate(); + } + //update weights if updatable component + if((*it)->IsUpdatable()) { + CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(**it); + if(rComp.LearnRate() > 0.0f) { + rComp.Update(); + } + } + //stop backprop if no updatable components precede current component + if(mpPropagErrorStopper == *it) break; + } + } + + + inline size_t + CuNetwork:: + GetNInputs() const + { + if(!mNetComponents.size() > 0) return 0; + return mNetComponents.front()->GetNInputs(); + } + + + inline size_t + CuNetwork:: + GetNOutputs() const + { + if(!mNetComponents.size() > 0) return 0; + return mNetComponents.back()->GetNOutputs(); + } + +} //namespace + +#endif + + |