summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuNetwork.h
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/CuTNetLib/cuNetwork.h
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/CuTNetLib/cuNetwork.h')
-rw-r--r--src/CuTNetLib/cuNetwork.h227
1 files changed, 227 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuNetwork.h b/src/CuTNetLib/cuNetwork.h
new file mode 100644
index 0000000..05e0ecb
--- /dev/null
+++ b/src/CuTNetLib/cuNetwork.h
@@ -0,0 +1,227 @@
+#ifndef _CUNETWORK_H_
+#define _CUNETWORK_H_
+
+#include "cuComponent.h"
+
+#include "cuBiasedLinearity.h"
+//#include "cuBlockLinearity.h"
+//#include "cuBias.h"
+//#include "cuWindow.h"
+
+#include "cuActivation.h"
+
+#include "cuCRBEDctFeat.h"
+
+#include "Vector.h"
+
+#include <sstream>
+#include <vector>
+
+/**
+ * \file cuNetwork.h
+ * \brief CuNN manipulation class
+ */
+
+/// \defgroup CuNNComp CuNN Components
+
+namespace TNet {
+ /**
+ * \brief Neural Network Manipulator & public interfaces
+ *
+ * \ingroup CuNNComp
+ */
+ class CuNetwork
+ {
+   //////////////////////////////////////
+   // Typedefs
+   /// Ordered container of layer pointers (front = input side, back = output side)
+   typedef std::vector<CuComponent*> LayeredType;
+
+   //////////////////////////////////////
+   // Disable copy construction and assignment:
+   // the destructor deletes every stored layer, so a member-wise copy
+   // would end up deleting the same components twice.
+   private:
+   CuNetwork(CuNetwork&);
+   CuNetwork& operator=(CuNetwork&);
+
+   public:
+   /// Creates an empty network (no layers, no error stopper, zero learn rate).
+   /// All scalar/pointer members are initialized so that Backpropagate()
+   /// never compares against an indeterminate stopper pointer.
+   CuNetwork()
+     : mpPropagErrorStopper(NULL), mGlobLearnRate(0.0), mpLearnRateFactors(NULL)
+   { }
+   /// Creates the network by calling ReadNetwork() on the given stream
+   CuNetwork(std::istream& rIn);
+   /// Deletes all layers held by the network
+   ~CuNetwork();
+
+   /// Appends a layer at the output end; the network deletes it on destruction
+   void AddLayer(CuComponent* layer);
+
+   /// Number of layers currently stored
+   int Layers() const
+   { return static_cast<int>(mNetComponents.size()); }
+
+   /// Access the i-th layer (0 = input side); the index is not range-checked
+   CuComponent& Layer(int i)
+   { return *mNetComponents[i]; }
+
+   /// forward the data to the output
+   void Propagate(CuMatrix<BaseFloat>& in, CuMatrix<BaseFloat>& out);
+
+   /// backpropagate the error while updating weights
+   void Backpropagate(CuMatrix<BaseFloat>& globerr);
+
+   void ReadNetwork(const char* pSrc);     ///< read the network from file
+   void WriteNetwork(const char* pDst);    ///< write network to file
+
+   void ReadNetwork(std::istream& rIn);    ///< read the network from stream
+   void WriteNetwork(std::ostream& rOut);  ///< write network to stream
+
+   size_t GetNInputs() const;  ///< Dimensionality of the input features (0 for an empty network)
+   size_t GetNOutputs() const; ///< Dimensionality of the desired vectors (0 for an empty network)
+
+   /// set the learning rate
+   void SetLearnRate(BaseFloat learnRate, const char* pLearnRateFactors = NULL);
+   BaseFloat GetLearnRate();  ///< get the learning rate value
+   void PrintLearnRate();     ///< log the learning rate values
+
+   void SetMomentum(BaseFloat momentum);
+   void SetWeightcost(BaseFloat weightcost);
+   void SetL1(BaseFloat l1);
+
+   void SetGradDivFrm(bool div);
+
+   /// Reads a component from a stream
+   static CuComponent* ComponentReader(std::istream& rIn, CuComponent* pPred);
+   /// Dumps component into a stream
+   static void ComponentDumper(std::ostream& rOut, CuComponent& rComp);
+
+
+   private:
+   /// Creates a component by reading from stream
+   CuComponent* ComponentFactory(std::istream& In);
+
+
+   private:
+   LayeredType mNetComponents; ///< container with the network layers
+   /// Layer at which Backpropagate() stops: it is still updated, but its own
+   /// Backpropagate() is skipped and no earlier layer is visited.
+   /// Where this pointer gets assigned is not visible in this header.
+   CuComponent* mpPropagErrorStopper;
+   BaseFloat mGlobLearnRate; ///< The global (unscaled) learn rate of the network
+   /// NOTE(review): presumably the pLearnRateFactors string handed to
+   /// SetLearnRate(); the assignment and the pointer's ownership are not
+   /// visible in this header - confirm against the implementation file.
+   const char* mpLearnRateFactors;
+
+
+   //friend class NetworkGenerator; //<< For generating networks...
+
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // INLINE FUNCTIONS
+ // CuNetwork::
+ inline
+ CuNetwork::
+ CuNetwork(std::istream& rSource)
+ : mpPropagErrorStopper(NULL), mGlobLearnRate(0.0), mpLearnRateFactors(NULL)
+ {
+ ReadNetwork(rSource);
+ }
+
+
+ /// Destructor: the network owns its layers, so each stored component
+ /// is deleted and the container is emptied afterwards.
+ inline
+ CuNetwork::~CuNetwork()
+ {
+   for(size_t idx = 0; idx < mNetComponents.size(); ++idx) {
+     delete mNetComponents[idx];
+     mNetComponents[idx] = NULL;
+   }
+   mNetComponents.clear();
+ }
+
+
+ /**
+  * Appends a layer at the output end of the network.
+  *
+  * The destructor deletes every stored layer, so the caller hands the
+  * pointer over to the network. When the network already has layers, the
+  * new layer's input dimension has to equal the current output dimension
+  * (otherwise Error() is called), and the new layer is linked to the
+  * previous last layer through SetPrevious()/SetNext().
+  *
+  * \param layer component to append; must not be NULL
+  */
+ inline void
+ CuNetwork::
+ AddLayer(CuComponent* layer)
+ {
+   //a NULL layer would be dereferenced below or later in Propagate()
+   if(NULL == layer) {
+     Error("Cannot add a NULL layer");
+     return;
+   }
+
+   if(!mNetComponents.empty()) {
+     if(GetNOutputs() != layer->GetNInputs()) {
+       //report both dims, the same way Propagate() does
+       std::ostringstream os;
+       os << "Nonmatching dims"
+          << " network output dim is: " << GetNOutputs()
+          << " new layer needs: " << layer->GetNInputs();
+       Error(os.str());
+     }
+     //chain the new layer behind the current last one
+     layer->SetPrevious(mNetComponents.back());
+     mNetComponents.back()->SetNext(layer);
+   }
+   mNetComponents.push_back(layer);
+ }
+
+
+ /**
+  * Runs the input through all layers and copies the result to `out`.
+  *
+  * With no layers, `out` simply becomes a copy of `in`. Otherwise the
+  * column count of `in` is checked against GetNInputs() (Error() on
+  * mismatch), `in` is set as the input of the first layer, every layer's
+  * Propagate() is called front to back, and the last layer's output is
+  * copied into `out`.
+  */
+ inline void
+ CuNetwork::
+ Propagate(CuMatrix<BaseFloat>& in, CuMatrix<BaseFloat>& out)
+ {
+   //no layers: pass the data through unchanged
+   if(mNetComponents.empty()) {
+     out.CopyFrom(in);
+     return;
+   }
+
+   //the data dimension has to agree with the first layer
+   if(in.Cols() != GetNInputs()) {
+     std::ostringstream msg;
+     msg << "Nonmatching dims";
+     msg << " data dim is: " << in.Cols();
+     msg << " network needs: " << GetNInputs();
+     Error(msg.str());
+   }
+
+   //feed the first layer, then run the layers in order
+   mNetComponents.front()->SetInput(in);
+   for(size_t idx = 0; idx < mNetComponents.size(); ++idx) {
+     mNetComponents[idx]->Propagate();
+   }
+
+   //hand the last layer's output back to the caller
+   out.CopyFrom(mNetComponents.back()->GetOutput());
+ }
+
+
+
+
+ /**
+  * Backpropagates the global error through the layers (last to first)
+  * and updates the weights on the way.
+  *
+  * `globerr` is set as the error input of the last layer. Each visited
+  * layer then
+  *  - has its Backpropagate() called, unless it is the error stopper,
+  *  - is updated when IsUpdatable() holds and its LearnRate() is positive.
+  * The walk ends after the stopper layer has been handled.
+  *
+  * An empty network has nothing to backpropagate through, so the call
+  * returns immediately (calling back() on an empty vector is undefined).
+  */
+ inline void
+ CuNetwork::
+ Backpropagate(CuMatrix<BaseFloat>& globerr)
+ {
+   //empty network => nothing to do
+   if(mNetComponents.empty()) {
+     return;
+   }
+
+   mNetComponents.back()->SetErrorInput(globerr);
+
+   // back-propagation
+   LayeredType::reverse_iterator it;
+   for(it=mNetComponents.rbegin(); it!=mNetComponents.rend(); ++it) {
+     //stopper component does not propagate error (no updatable predecessors)
+     if(*it != mpPropagErrorStopper) {
+       //compute errors for preceding network components
+       (*it)->Backpropagate();
+     }
+     //update weights if updatable component
+     if((*it)->IsUpdatable()) {
+       CuUpdatableComponent& rComp = dynamic_cast<CuUpdatableComponent&>(**it);
+       //a non-positive learn rate leaves the component unchanged
+       if(rComp.LearnRate() > 0.0f) {
+         rComp.Update();
+       }
+     }
+     //stop backprop if no updatable components precede current component
+     if(mpPropagErrorStopper == *it) break;
+   }
+ }
+
+
+ /// Input dimensionality of the network: GetNInputs() of the first layer,
+ /// or 0 when the network has no layers.
+ inline size_t
+ CuNetwork::
+ GetNInputs() const
+ {
+   //empty() states the intent directly; the former `!size() > 0`
+   //parsed as `(!size()) > 0` and only worked by accident
+   if(mNetComponents.empty()) return 0;
+   return mNetComponents.front()->GetNInputs();
+ }
+
+
+ /// Output dimensionality of the network: GetNOutputs() of the last layer,
+ /// or 0 when the network has no layers.
+ inline size_t
+ CuNetwork::
+ GetNOutputs() const
+ {
+   //empty() states the intent directly; the former `!size() > 0`
+   //parsed as `(!size()) > 0` and only worked by accident
+   if(mNetComponents.empty()) return 0;
+   return mNetComponents.back()->GetNOutputs();
+ }
+
+} //namespace
+
+#endif
+
+