summaryrefslogtreecommitdiff
path: root/src/TNetLib/Nnet.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/TNetLib/Nnet.h')
-rw-r--r--src/TNetLib/Nnet.h194
1 files changed, 194 insertions, 0 deletions
diff --git a/src/TNetLib/Nnet.h b/src/TNetLib/Nnet.h
new file mode 100644
index 0000000..12e2585
--- /dev/null
+++ b/src/TNetLib/Nnet.h
@@ -0,0 +1,194 @@
+#ifndef _NETWORK_H_
+#define _NETWORK_H_
+
+#include "Component.h"
+#include "BiasedLinearity.h"
+#include "SharedLinearity.h"
+#include "Activation.h"
+
+#include "Vector.h"
+
+#include <vector>
+
+
+namespace TNet {
+
+class Network
+{
+//////////////////////////////////////
+// Typedefs
+typedef std::vector<Component*> LayeredType;
+
+ //////////////////////////////////////
+ // Disable copy construction and assignment
+ private:
+ Network(Network&);
+ Network& operator=(Network&);
+
+ public:
+ // allow incomplete network creation
+ Network()
+ { }
+
+ ~Network();
+
+ int Layers() const
+ { return mNnet.size(); }
+
+ Component& Layer(int i)
+ { return *mNnet[i]; }
+
+ const Component& Layer(int i) const
+ { return *mNnet[i]; }
+
+ /// Feedforward the data per blocks, this needs less memory,
+ /// and allows to process very long files.
+ /// It does not trim the *_frm_ext, but uses it
+ /// for concatenation of segments
+ void Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out,
+ size_t start_frm_ext, size_t end_frm_ext);
+ /// forward the data to the output
+ void Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out);
+ /// backpropagate the error while calculating the gradient
+ void Backpropagate(const Matrix<BaseFloat>& globerr);
+
+ /// accumulate the gradient from other networks
+ void AccuGradient(const Network& src, int thr, int thrN);
+ /// update weights, reset the accumulator
+ void Update(int thr, int thrN);
+
+ Network* Clone(); ///< Clones the network
+
+ void ReadNetwork(const char* pSrc); ///< read the network from file
+ void ReadNetwork(std::istream& rIn); ///< read the network from stream
+ void WriteNetwork(const char* pDst); ///< write network to file
+ void WriteNetwork(std::ostream& rOut); ///< write network to stream
+
+ size_t GetNInputs() const; ///< Dimensionality of the input features
+ size_t GetNOutputs() const; ///< Dimensionality of the desired vectors
+
+ void SetLearnRate(BaseFloat learnRate); ///< set the learning rate value
+ BaseFloat GetLearnRate(); ///< get the learning rate value
+
+ void SetWeightcost(BaseFloat l2); ///< set the L2 regularization const
+
+ void ResetBunchsize(); ///< reset the frame counter (needed for L2 regularization
+ void AccuBunchsize(const Network& src); ///< accumulate frame counts in bunch (needed in L2 regularization
+
+ private:
+ /// Creates a component by reading from stream
+ Component* ComponentFactory(std::istream& In);
+ /// Dumps component into a stream
+ void ComponentDumper(std::ostream& rOut, Component& rComp);
+
+ private:
+ LayeredType mNnet; ///< container with the network layers
+
+};
+
+
+//////////////////////////////////////////////////////////////////////////
+// INLINE FUNCTIONS
+// Network::
+inline Network::~Network() {
+ //delete all the components
+ LayeredType::iterator it;
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ delete *it;
+ }
+}
+
+
+inline size_t Network::GetNInputs() const {
+ assert(mNnet.size() > 0);
+ return mNnet.front()->GetNInputs();
+}
+
+
+inline size_t
+Network::
+GetNOutputs() const
+{
+ assert(mNnet.size() > 0);
+ return mNnet.back()->GetNOutputs();
+}
+
+
+
+inline void
+Network::
+SetLearnRate(BaseFloat learnRate)
+{
+ LayeredType::iterator it;
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ if((*it)->IsUpdatable()) {
+ dynamic_cast<UpdatableComponent*>(*it)->LearnRate(learnRate);
+ }
+ }
+}
+
+
+inline BaseFloat
+Network::
+GetLearnRate()
+{
+ //TODO - learn rates may differ layer to layer
+ assert(mNnet.size() > 0);
+ for(size_t i=0; i<mNnet.size(); i++) {
+ if(mNnet[i]->IsUpdatable()) {
+ return dynamic_cast<UpdatableComponent*>(mNnet[i])->LearnRate();
+ }
+ }
+ Error("No updatable NetComponents");
+ return -1;
+}
+
+
+inline void
+Network::
+SetWeightcost(BaseFloat l2)
+{
+ LayeredType::iterator it;
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ if((*it)->IsUpdatable()) {
+ dynamic_cast<UpdatableComponent*>(*it)->Weightcost(l2);
+ }
+ }
+}
+
+
+inline void
+Network::
+ResetBunchsize()
+{
+ LayeredType::iterator it;
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ if((*it)->IsUpdatable()) {
+ dynamic_cast<UpdatableComponent*>(*it)->Bunchsize(0);
+ }
+ }
+}
+
+inline void
+Network::
+AccuBunchsize(const Network& src)
+{
+ assert(Layers() == src.Layers());
+ assert(Layers() > 0);
+
+ for(int i=0; i<Layers(); i++) {
+ if(Layer(i).IsUpdatable()) {
+ UpdatableComponent& tgt_comp = dynamic_cast<UpdatableComponent&>(Layer(i));
+ const UpdatableComponent& src_comp = dynamic_cast<const UpdatableComponent&>(src.Layer(i));
+ tgt_comp.Bunchsize(tgt_comp.Bunchsize()+src_comp.GetOutput().Rows());
+ }
+ }
+}
+
+
+
+} //namespace
+
+#endif
+
+