Diffstat (limited to 'src/TNetLib/Nnet.cc')
-rw-r--r--  src/TNetLib/Nnet.cc  360
1 file changed, 360 insertions, 0 deletions
diff --git a/src/TNetLib/Nnet.cc b/src/TNetLib/Nnet.cc
new file mode 100644
index 0000000..4b364ac
--- /dev/null
+++ b/src/TNetLib/Nnet.cc
@@ -0,0 +1,360 @@
+
+#include <algorithm>
+//#include <locale>
+#include <cctype>
+#include <fstream>   //for std::ifstream/std::ofstream used in Read/WriteNetwork
+
+#include "Nnet.h"
+#include "CRBEDctFeat.h"
+#include "BlockArray.h"
+
+namespace TNet {
+
+
+
+
+void Network::Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out,
+ size_t start_frm_ext, size_t end_frm_ext) {
+ //empty network: copy input to output
+ if(mNnet.size() == 0) {
+ if(out.Rows() != in.Rows() || out.Cols() != in.Cols()) {
+ out.Init(in.Rows(),in.Cols());
+ }
+ out.Copy(in);
+ return;
+ }
+
+ //short input: propagate in one block
+ if(in.Rows() < 5000) {
+ Propagate(in,out);
+ } else {//long input: propagate per parts
+ //initialize
+ out.Init(in.Rows(),GetNOutputs());
+ Matrix<BaseFloat> tmp_in, tmp_out;
+ int done=0, block=1024;
+ //propagate first part
+ tmp_in.Init(block+end_frm_ext,in.Cols());
+ tmp_in.Copy(in.Range(0,block+end_frm_ext,0,in.Cols()));
+ Propagate(tmp_in,tmp_out);
+ out.Range(0,block,0,tmp_out.Cols()).Copy(
+ tmp_out.Range(0,block,0,tmp_out.Cols())
+ );
+ done += block;
+ //propagate middle parts
+ while((done+2*block) < in.Rows()) {
+ tmp_in.Init(block+start_frm_ext+end_frm_ext,in.Cols());
+ tmp_in.Copy(in.Range(done-start_frm_ext, block+start_frm_ext+end_frm_ext, 0,in.Cols()));
+ Propagate(tmp_in,tmp_out);
+ out.Range(done,block,0,tmp_out.Cols()).Copy(
+ tmp_out.Range(start_frm_ext,block,0,tmp_out.Cols())
+ );
+ done += block;
+ }
+ //propagate last part
+ tmp_in.Init(in.Rows()-done+start_frm_ext,in.Cols());
+ tmp_in.Copy(in.Range(done-start_frm_ext,in.Rows()-done+start_frm_ext,0,in.Cols()));
+ Propagate(tmp_in,tmp_out);
+ out.Range(done,out.Rows()-done,0,out.Cols()).Copy(
+ tmp_out.Range(start_frm_ext,tmp_out.Rows()-start_frm_ext,0,tmp_out.Cols())
+ );
+
+ done += tmp_out.Rows()-start_frm_ext;
+ assert(done == out.Rows());
+ }
+}
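+
+// Usage sketch (illustrative, not part of the original file; 'net' and
+// 'feats' are hypothetical names, and the frame-extension values depend
+// on the caller's front-end):
+//
+//   Matrix<BaseFloat> feats, posteriors;
+//   //...fill 'feats' with one utterance...
+//   net.Feedforward(feats, posteriors, 5, 5);
+//
+// For long inputs (>= 5000 rows) Feedforward propagates 1024-frame blocks
+// that overlap by start_frm_ext/end_frm_ext frames, so components with
+// temporal context (e.g. frame splicing) see valid context at block edges.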
+
+
+void Network::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out) {
+ //empty network: copy input to output
+ if(mNnet.size() == 0) {
+ if(out.Rows() != in.Rows() || out.Cols() != in.Cols()) {
+ out.Init(in.Rows(),in.Cols());
+ }
+ out.Copy(in);
+ return;
+ }
+
+ //this will keep pointer to matrix 'in', for backprop
+ mNnet.front()->SetInput(in);
+
+ //propagate
+ LayeredType::iterator it;
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ (*it)->Propagate();
+ }
+
+ //copy the output matrix
+ const Matrix<BaseFloat>& mat = mNnet.back()->GetOutput();
+ if(out.Rows() != mat.Rows() || out.Cols() != mat.Cols()) {
+ out.Init(mat.Rows(),mat.Cols());
+ }
+ out.Copy(mat);
+
+}
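+
+// How the forward pass is wired (a sketch of the existing behavior, not
+// new functionality): the components form a chain built at load time, so
+//
+//   in -> [comp 0] -> [comp 1] -> ... -> [comp N-1] -> out
+//
+// SetInput() only stores a pointer to 'in' inside the first component;
+// each Propagate() then fills that component's output matrix, which the
+// following component reads as its input.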
+
+
+void Network::Backpropagate(const Matrix<BaseFloat>& globerr) {
+ //pass matrix to last component
+ mNnet.back()->SetErrorInput(globerr);
+
+ //back-propagation: traverse the components in reversed order
+ LayeredType::reverse_iterator it;
+ for(it=mNnet.rbegin(); it!=mNnet.rend(); ++it) {
+ //first component does not backpropagate error (no predecessors)
+ if(*it != mNnet.front()) {
+ (*it)->Backpropagate();
+ }
+ //compute gradient if updatable component
+ if((*it)->IsUpdatable()) {
+ UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it);
+ comp.Gradient(); //compute gradient
+ }
+ }
+}
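+
+// Sketch of one training step (hypothetical caller code; 'globerr' must be
+// filled by the caller, e.g. posteriors minus targets for cross-entropy
+// training with a softmax output; the Update(0,1) call assumes the
+// (thread index, thread count) meaning of thr/thrN):
+//
+//   net.Propagate(feats, posteriors);
+//   //...compute 'globerr' from 'posteriors' and the targets...
+//   net.Backpropagate(globerr);
+//   net.Update(0, 1);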
+
+
+void Network::AccuGradient(const Network& src, int thr, int thrN) {
+ LayeredType::iterator it;
+ LayeredType::const_iterator it2;
+
+ for(it=mNnet.begin(), it2=src.mNnet.begin(); it!=mNnet.end(); ++it,++it2) {
+ if((*it)->IsUpdatable()) {
+ UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it);
+ const UpdatableComponent& comp2 = dynamic_cast<const UpdatableComponent&>(**it2);
+ comp.AccuGradient(comp2,thr,thrN);
+ }
+ }
+}
+
+
+void Network::Update(int thr, int thrN) {
+ LayeredType::iterator it;
+
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ if((*it)->IsUpdatable()) {
+ UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it);
+ comp.Update(thr,thrN);
+ }
+ }
+}
+
+
+Network* Network::Clone() {
+ Network* net = new Network;
+ LayeredType::iterator it;
+ for(it = mNnet.begin(); it != mNnet.end(); ++it) {
+ //clone
+ net->mNnet.push_back((*it)->Clone());
+ //connect network
+ if(net->mNnet.size() > 1) {
+ Component* last = *(net->mNnet.end()-1);
+ Component* prev = *(net->mNnet.end()-2);
+ last->SetInput(prev->GetOutput());
+ prev->SetErrorInput(last->GetErrorOutput());
+ }
+ }
+
+ //copy the learning rate
+ //net->SetLearnRate(GetLearnRate());
+
+ return net;
+}
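+
+// Sketch of the multi-threaded pattern these methods support (an
+// assumption read off the thr/thrN arguments, not spelled out in this
+// file): every worker trains on its own Clone() and the gradients are
+// merged back into the master network in per-thread slices.
+//
+//   Network* worker = master.Clone();        //per-thread replica
+//   //...worker->Propagate()/Backpropagate() on this thread's data...
+//   master.AccuGradient(*worker, thr, thrN); //merge this thread's slice
+//   master.Update(thr, thrN);                //apply the merged gradient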
+
+
+void Network::ReadNetwork(const char* pSrc) {
+ std::ifstream in(pSrc);
+ if(!in.good()) {
+ Error(std::string("Error, cannot read model: ")+pSrc);
+ }
+ ReadNetwork(in);
+ in.close();
+}
+
+
+
+void Network::ReadNetwork(std::istream& rIn) {
+ //get the network elements from a factory
+ Component *pComp;
+ while(NULL != (pComp = ComponentFactory(rIn)))
+ mNnet.push_back(pComp);
+}
+
+
+void Network::WriteNetwork(const char* pDst) {
+ std::ofstream out(pDst);
+ if(!out.good()) {
+ Error(std::string("Error, cannot write model: ")+pDst);
+ }
+ WriteNetwork(out);
+ out.close();
+}
+
+
+void Network::WriteNetwork(std::ostream& rOut) {
+ //dump all the components
+ LayeredType::iterator it;
+ for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+ ComponentDumper(rOut, **it);
+ }
+}
+
+
+Component*
+Network::
+ComponentFactory(std::istream& rIn)
+{
+ rIn >> std::ws;
+ if(rIn.eof()) return NULL;
+
+ Component* pRet=NULL;
+ Component* pPred=NULL;
+
+ std::string componentTag;
+ size_t nInputs, nOutputs;
+
+ rIn >> std::ws;
+ rIn >> componentTag;
+ if(componentTag == "") return NULL; //nothing left in the file
+
+ //make it lowercase
+ std::transform(componentTag.begin(), componentTag.end(),
+ componentTag.begin(), tolower);
+
+ //the 'endblock' tag terminates the network
+ if(componentTag == "<endblock>") return NULL;
+
+
+ if(componentTag[0] != '<' || componentTag[componentTag.size()-1] != '>') {
+ Error(std::string("Invalid component tag:")+componentTag);
+ }
+
+ rIn >> std::ws;
+ rIn >> nOutputs;
+ rIn >> std::ws;
+ rIn >> nInputs;
+ assert(nInputs > 0 && nOutputs > 0);
+
+ //make coupling with predecessor
+ if(mNnet.size() == 0) {
+ pPred = NULL;
+ } else {
+ pPred = mNnet.back();
+ }
+
+ //array with list of component tags
+ static const std::string TAGS[] = {
+ "<biasedlinearity>",
+ "<sharedlinearity>",
+
+ "<sigmoid>",
+ "<softmax>",
+ "<blocksoftmax>",
+
+ "<expand>",
+ "<copy>",
+ "<transpose>",
+ "<blocklinearity>",
+ "<bias>",
+ "<window>",
+ "<log>",
+
+ "<blockarray>",
+ };
+
+ static const int n_tags = sizeof(TAGS) / sizeof(TAGS[0]);
+ int i = 0;
+ for(i=0; i<n_tags; i++) {
+ if(componentTag == TAGS[i]) break;
+ }
+
+ //switch according to position in array TAGS
+ switch(i) {
+ case 0: pRet = new BiasedLinearity(nInputs,nOutputs,pPred); break;
+ case 1: pRet = new SharedLinearity(nInputs,nOutputs,pPred); break;
+
+ case 2: pRet = new Sigmoid(nInputs,nOutputs,pPred); break;
+ case 3: pRet = new Softmax(nInputs,nOutputs,pPred); break;
+ case 4: pRet = new BlockSoftmax(nInputs,nOutputs,pPred); break;
+
+ case 5: pRet = new Expand(nInputs,nOutputs,pPred); break;
+ case 6: pRet = new Copy(nInputs,nOutputs,pPred); break;
+ case 7: pRet = new Transpose(nInputs,nOutputs,pPred); break;
+ case 8: pRet = new BlockLinearity(nInputs,nOutputs,pPred); break;
+ case 9: pRet = new Bias(nInputs,nOutputs,pPred); break;
+ case 10: pRet = new Window(nInputs,nOutputs,pPred); break;
+ case 11: pRet = new Log(nInputs,nOutputs,pPred); break;
+
+ case 12: pRet = new BlockArray(nInputs,nOutputs,pPred); break;
+
+ default: Error(std::string("Unknown component tag: ")+componentTag);
+ }
+
+ //read the parameters (if any)
+ pRet->ReadFromStream(rIn);
+ //return
+ return pRet;
+}
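+
+// The factory thus reads a plain-text stream of the form
+//
+//   <tag> <nOutputs> <nInputs> [parameters...]
+//
+// For illustration only (each component defines its own parameter layout
+// in ReadFromStream), a two-layer model could begin like:
+//
+//   <biasedlinearity> 1024 39
+//   ...weight matrix and bias vector...
+//   <sigmoid> 1024 1024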
+
+
+void
+Network::
+ComponentDumper(std::ostream& rOut, Component& rComp)
+{
+ //map each component's type code to its text tag;
+ //the TYPES and TAGS arrays below must stay index-aligned
+ static const Component::ComponentType TYPES[] = {
+ Component::BIASED_LINEARITY,
+ Component::SHARED_LINEARITY,
+
+ Component::SIGMOID,
+ Component::SOFTMAX,
+ Component::BLOCK_SOFTMAX,
+
+ Component::EXPAND,
+ Component::COPY,
+ Component::TRANSPOSE,
+ Component::BLOCK_LINEARITY,
+ Component::BIAS,
+ Component::WINDOW,
+ Component::LOG,
+
+ Component::BLOCK_ARRAY,
+ };
+ static const std::string TAGS[] = {
+ "<biasedlinearity>",
+ "<sharedlinearity>",
+
+ "<sigmoid>",
+ "<softmax>",
+ "<blocksoftmax>",
+
+ "<expand>",
+ "<copy>",
+ "<transpose>",
+ "<blocklinearity>",
+ "<bias>",
+ "<window>",
+ "<log>",
+
+ "<blockarray>",
+ };
+ static const int MAX = sizeof TYPES / sizeof TYPES[0];
+
+ int i;
+ for(i=0; i<MAX; ++i) {
+ if(TYPES[i] == rComp.GetType()) break;
+ }
+ if(i == MAX) Error("Unknown ComponentType");
+
+ //dump the component tag
+ rOut << TAGS[i] << " "
+ << rComp.GetNOutputs() << " "
+ << rComp.GetNInputs() << std::endl;
+
+ //dump the parameters (if any)
+ rComp.WriteToStream(rOut);
+}
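+
+// Round-trip sketch: ComponentDumper emits exactly what ComponentFactory
+// parses (tag, nOutputs, nInputs, then the parameters), so a model can be
+// saved and re-loaded like this ('model.nnet' is a hypothetical path):
+//
+//   net.WriteNetwork("model.nnet");
+//   Network net2;
+//   net2.ReadNetwork("model.nnet");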
+
+
+
+
+} //namespace TNet
+