From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- src/TNetLib/.svn/text-base/Nnet.h.svn-base | 194 +++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 src/TNetLib/.svn/text-base/Nnet.h.svn-base (limited to 'src/TNetLib/.svn/text-base/Nnet.h.svn-base') diff --git a/src/TNetLib/.svn/text-base/Nnet.h.svn-base b/src/TNetLib/.svn/text-base/Nnet.h.svn-base new file mode 100644 index 0000000..12e2585 --- /dev/null +++ b/src/TNetLib/.svn/text-base/Nnet.h.svn-base @@ -0,0 +1,194 @@ +#ifndef _NETWORK_H_ +#define _NETWORK_H_ + +#include "Component.h" +#include "BiasedLinearity.h" +#include "SharedLinearity.h" +#include "Activation.h" + +#include "Vector.h" + +#include + + +namespace TNet { + +class Network +{ +////////////////////////////////////// +// Typedefs +typedef std::vector LayeredType; + + ////////////////////////////////////// + // Disable copy construction and assignment + private: + Network(Network&); + Network& operator=(Network&); + + public: + // allow incomplete network creation + Network() + { } + + ~Network(); + + int Layers() const + { return mNnet.size(); } + + Component& Layer(int i) + { return *mNnet[i]; } + + const Component& Layer(int i) const + { return *mNnet[i]; } + + /// Feedforward the data per blocks, this needs less memory, + /// and allows to process very long files. + /// It does not trim the *_frm_ext, but uses it + /// for concatenation of segments + void Feedforward(const Matrix& in, Matrix& out, + size_t start_frm_ext, size_t end_frm_ext); + /// forward the data to the output + void Propagate(const Matrix& in, Matrix& out); + /// backpropagate the error while calculating the gradient + void Backpropagate(const Matrix& globerr); + + /// accumulate the gradient from other networks + void AccuGradient(const Network& src, int thr, int thrN); + /// update weights, reset the accumulator + void Update(int thr, int thrN); + + Network* Clone(); ///< Clones the network + + void ReadNetwork(const char* pSrc); ///< read the network from file + void ReadNetwork(std::istream& rIn); ///< read the network from stream + void WriteNetwork(const char* pDst); ///< write network to file + void WriteNetwork(std::ostream& rOut); ///< write network to stream + + size_t GetNInputs() const; ///< Dimensionality of the input features + size_t GetNOutputs() const; ///< Dimensionality of the desired vectors + + void SetLearnRate(BaseFloat learnRate); ///< set the learning rate value + BaseFloat GetLearnRate(); ///< get the learning rate value + + void SetWeightcost(BaseFloat l2); ///< set the L2 regularization const + + void ResetBunchsize(); ///< reset the frame counter (needed for L2 regularization + void AccuBunchsize(const Network& src); ///< accumulate frame counts in bunch (needed in L2 regularization + + private: + /// Creates a component by reading from stream + Component* ComponentFactory(std::istream& In); + /// Dumps component into a stream + void ComponentDumper(std::ostream& rOut, Component& rComp); + + private: + LayeredType mNnet; ///< container with the network layers + +}; + + +////////////////////////////////////////////////////////////////////////// +// INLINE FUNCTIONS +// Network:: +inline Network::~Network() { + //delete all the components + LayeredType::iterator it; + for(it=mNnet.begin(); it!=mNnet.end(); ++it) { + delete *it; + } +} + + +inline size_t Network::GetNInputs() const { + assert(mNnet.size() > 0); + return mNnet.front()->GetNInputs(); +} + + +inline size_t +Network:: +GetNOutputs() const +{ + assert(mNnet.size() > 0); + return mNnet.back()->GetNOutputs(); +} + + + +inline void +Network:: +SetLearnRate(BaseFloat learnRate) +{ + LayeredType::iterator it; + for(it=mNnet.begin(); it!=mNnet.end(); ++it) { + if((*it)->IsUpdatable()) { + dynamic_cast(*it)->LearnRate(learnRate); + } + } +} + + +inline BaseFloat +Network:: +GetLearnRate() +{ + //TODO - learn rates may differ layer to layer + assert(mNnet.size() > 0); + for(size_t i=0; iIsUpdatable()) { + return dynamic_cast(mNnet[i])->LearnRate(); + } + } + Error("No updatable NetComponents"); + return -1; +} + + +inline void +Network:: +SetWeightcost(BaseFloat l2) +{ + LayeredType::iterator it; + for(it=mNnet.begin(); it!=mNnet.end(); ++it) { + if((*it)->IsUpdatable()) { + dynamic_cast(*it)->Weightcost(l2); + } + } +} + + +inline void +Network:: +ResetBunchsize() +{ + LayeredType::iterator it; + for(it=mNnet.begin(); it!=mNnet.end(); ++it) { + if((*it)->IsUpdatable()) { + dynamic_cast(*it)->Bunchsize(0); + } + } +} + +inline void +Network:: +AccuBunchsize(const Network& src) +{ + assert(Layers() == src.Layers()); + assert(Layers() > 0); + + for(int i=0; i(Layer(i)); + const UpdatableComponent& src_comp = dynamic_cast(src.Layer(i)); + tgt_comp.Bunchsize(tgt_comp.Bunchsize()+src_comp.GetOutput().Rows()); + } + } +} + + + +} //namespace + +#endif + + -- cgit v1.2.3-70-g09d2