Diffstat (limited to 'src/TNetLib/Nnet.cc')
-rw-r--r-- | src/TNetLib/Nnet.cc | 360
1 file changed, 360 insertions, 0 deletions
diff --git a/src/TNetLib/Nnet.cc b/src/TNetLib/Nnet.cc
new file mode 100644
index 0000000..4b364ac
--- /dev/null
+++ b/src/TNetLib/Nnet.cc
@@ -0,0 +1,360 @@
+
+#include <algorithm>
+//#include <locale>
+#include <cctype>
+
+#include "Nnet.h"
+#include "CRBEDctFeat.h"
+#include "BlockArray.h"
+
+namespace TNet {
+
+
+
+
+void Network::Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out,
+                          size_t start_frm_ext, size_t end_frm_ext) {
+  //empty network: copy input to output
+  if(mNnet.size() == 0) {
+    if(out.Rows() != in.Rows() || out.Cols() != in.Cols()) {
+      out.Init(in.Rows(),in.Cols());
+    }
+    out.Copy(in);
+    return;
+  }
+
+  //short input: propagate in one block
+  if(in.Rows() < 5000) {
+    Propagate(in,out);
+  } else { //long input: propagate in parts
+    //initialize
+    out.Init(in.Rows(),GetNOutputs());
+    Matrix<BaseFloat> tmp_in, tmp_out;
+    int done=0, block=1024;
+    //propagate first part
+    tmp_in.Init(block+end_frm_ext,in.Cols());
+    tmp_in.Copy(in.Range(0,block+end_frm_ext,0,in.Cols()));
+    Propagate(tmp_in,tmp_out);
+    out.Range(0,block,0,tmp_out.Cols()).Copy(
+      tmp_out.Range(0,block,0,tmp_out.Cols())
+    );
+    done += block;
+    //propagate middle parts
+    while((done+2*block) < in.Rows()) {
+      tmp_in.Init(block+start_frm_ext+end_frm_ext,in.Cols());
+      tmp_in.Copy(in.Range(done-start_frm_ext, block+start_frm_ext+end_frm_ext, 0,in.Cols()));
+      Propagate(tmp_in,tmp_out);
+      out.Range(done,block,0,tmp_out.Cols()).Copy(
+        tmp_out.Range(start_frm_ext,block,0,tmp_out.Cols())
+      );
+      done += block;
+    }
+    //propagate last part
+    tmp_in.Init(in.Rows()-done+start_frm_ext,in.Cols());
+    tmp_in.Copy(in.Range(done-start_frm_ext,in.Rows()-done+start_frm_ext,0,in.Cols()));
+    Propagate(tmp_in,tmp_out);
+    out.Range(done,out.Rows()-done,0,out.Cols()).Copy(
+      tmp_out.Range(start_frm_ext,tmp_out.Rows()-start_frm_ext,0,tmp_out.Cols())
+    );
+
+    done += tmp_out.Rows()-start_frm_ext;
+    assert(done == out.Rows());
+  }
+}
+
+
+void Network::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out) {
+  //empty network: copy input to output
+  if(mNnet.size() == 0) {
+    if(out.Rows() != in.Rows() || out.Cols() != in.Cols()) {
+      out.Init(in.Rows(),in.Cols());
+    }
+    out.Copy(in);
+    return;
+  }
+
+  //this will keep pointer to matrix 'in', for backprop
+  mNnet.front()->SetInput(in);
+
+  //propagate
+  LayeredType::iterator it;
+  for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+    (*it)->Propagate();
+  }
+
+  //copy the output matrix
+  const Matrix<BaseFloat>& mat = mNnet.back()->GetOutput();
+  if(out.Rows() != mat.Rows() || out.Cols() != mat.Cols()) {
+    out.Init(mat.Rows(),mat.Cols());
+  }
+  out.Copy(mat);
+
+}
+
+
+void Network::Backpropagate(const Matrix<BaseFloat>& globerr) {
+  //pass matrix to last component
+  mNnet.back()->SetErrorInput(globerr);
+
+  //back-propagation: reversed order
+  LayeredType::reverse_iterator it;
+  for(it=mNnet.rbegin(); it!=mNnet.rend(); ++it) {
+    //first component does not backpropagate error (no predecessors)
+    if(*it != mNnet.front()) {
+      (*it)->Backpropagate();
+    }
+    //compute gradient if updatable component
+    if((*it)->IsUpdatable()) {
+      UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it);
+      comp.Gradient(); //compute gradient
+    }
+  }
+}
+
+
+void Network::AccuGradient(const Network& src, int thr, int thrN) {
+  LayeredType::iterator it;
+  LayeredType::const_iterator it2;
+
+  for(it=mNnet.begin(), it2=src.mNnet.begin(); it!=mNnet.end(); ++it,++it2) {
+    if((*it)->IsUpdatable()) {
+      UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it);
+      const UpdatableComponent& comp2 = dynamic_cast<const UpdatableComponent&>(**it2);
+      comp.AccuGradient(comp2,thr,thrN);
+    }
+  }
+}
+
+
+void Network::Update(int thr, int thrN) {
+  LayeredType::iterator it;
+
+  for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+    if((*it)->IsUpdatable()) {
+      UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it);
+      comp.Update(thr,thrN);
+    }
+  }
+}
+
+
+Network* Network::Clone() {
+  Network* net = new Network;
+  LayeredType::iterator it;
+  for(it = mNnet.begin(); it != mNnet.end(); ++it) {
+    //clone
+    net->mNnet.push_back((*it)->Clone());
+    //connect network
+    if(net->mNnet.size() > 1) {
+      Component* last = *(net->mNnet.end()-1);
+      Component* prev = *(net->mNnet.end()-2);
+      last->SetInput(prev->GetOutput());
+      prev->SetErrorInput(last->GetErrorOutput());
+    }
+  }
+
+  //copy the learning rate
+  //net->SetLearnRate(GetLearnRate());
+
+  return net;
+}
+
+
+void Network::ReadNetwork(const char* pSrc) {
+  std::ifstream in(pSrc);
+  if(!in.good()) {
+    Error(std::string("Error, cannot read model: ")+pSrc);
+  }
+  ReadNetwork(in);
+  in.close();
+}
+
+
+
+void Network::ReadNetwork(std::istream& rIn) {
+  //get the network elements from a factory
+  Component *pComp;
+  while(NULL != (pComp = ComponentFactory(rIn)))
+    mNnet.push_back(pComp);
+}
+
+
+void Network::WriteNetwork(const char* pDst) {
+  std::ofstream out(pDst);
+  if(!out.good()) {
+    Error(std::string("Error, cannot write model: ")+pDst);
+  }
+  WriteNetwork(out);
+  out.close();
+}
+
+
+void Network::WriteNetwork(std::ostream& rOut) {
+  //dump all the components
+  LayeredType::iterator it;
+  for(it=mNnet.begin(); it!=mNnet.end(); ++it) {
+    ComponentDumper(rOut, **it);
+  }
+}
+
+
+Component*
+Network::
+ComponentFactory(std::istream& rIn)
+{
+  rIn >> std::ws;
+  if(rIn.eof()) return NULL;
+
+  Component* pRet=NULL;
+  Component* pPred=NULL;
+
+  std::string componentTag;
+  size_t nInputs, nOutputs;
+
+  rIn >> std::ws;
+  rIn >> componentTag;
+  if(componentTag == "") return NULL; //nothing left in the file
+
+  //make it lowercase
+  std::transform(componentTag.begin(), componentTag.end(),
+                 componentTag.begin(), tolower);
+
+  //the 'endblock' tag terminates the network
+  if(componentTag == "<endblock>") return NULL;
+
+
+  if(componentTag[0] != '<' || componentTag[componentTag.size()-1] != '>') {
+    Error(std::string("Invalid component tag:")+componentTag);
+  }
+
+  rIn >> std::ws;
+  rIn >> nOutputs;
+  rIn >> std::ws;
+  rIn >> nInputs;
+  assert(nInputs > 0 && nOutputs > 0);
+
+  //make coupling with predecessor
+  if(mNnet.size() == 0) {
+    pPred = NULL;
+  } else {
+    pPred = mNnet.back();
+  }
+
+  //array with list of component tags
+  static const std::string TAGS[] = {
+    "<biasedlinearity>",
+    "<sharedlinearity>",
+
+    "<sigmoid>",
+    "<softmax>",
+    "<blocksoftmax>",
+
+    "<expand>",
+    "<copy>",
+    "<transpose>",
+    "<blocklinearity>",
+    "<bias>",
+    "<window>",
+    "<log>",
+
+    "<blockarray>",
+  };
+
+  static const int n_tags = sizeof(TAGS) / sizeof(TAGS[0]);
+  int i = 0;
+  for(i=0; i<n_tags; i++) {
+    if(componentTag == TAGS[i]) break;
+  }
+
+  //switch according to position in array TAGS
+  switch(i) {
+    case 0: pRet = new BiasedLinearity(nInputs,nOutputs,pPred); break;
+    case 1: pRet = new SharedLinearity(nInputs,nOutputs,pPred); break;
+
+    case 2: pRet = new Sigmoid(nInputs,nOutputs,pPred); break;
+    case 3: pRet = new Softmax(nInputs,nOutputs,pPred); break;
+    case 4: pRet = new BlockSoftmax(nInputs,nOutputs,pPred); break;
+
+    case 5: pRet = new Expand(nInputs,nOutputs,pPred); break;
+    case 6: pRet = new Copy(nInputs,nOutputs,pPred); break;
+    case 7: pRet = new Transpose(nInputs,nOutputs,pPred); break;
+    case 8: pRet = new BlockLinearity(nInputs,nOutputs,pPred); break;
+    case 9: pRet = new Bias(nInputs,nOutputs,pPred); break;
+    case 10: pRet = new Window(nInputs,nOutputs,pPred); break;
+    case 11: pRet = new Log(nInputs,nOutputs,pPred); break;
+
+    case 12: pRet = new BlockArray(nInputs,nOutputs,pPred); break;
+
+    default: Error(std::string("Unknown Component tag:")+componentTag);
+  }
+
+  //read params if it is updatable component
+  pRet->ReadFromStream(rIn);
+  //return
+  return pRet;
+}
+
+
+void
+Network::
+ComponentDumper(std::ostream& rOut, Component& rComp)
+{
+  //use tags of all the components; or the identification codes
+  //array with list of component tags
+  static const Component::ComponentType TYPES[] = {
+    Component::BIASED_LINEARITY,
+    Component::SHARED_LINEARITY,
+
+    Component::SIGMOID,
+    Component::SOFTMAX,
+    Component::BLOCK_SOFTMAX,
+
+    Component::EXPAND,
+    Component::COPY,
+    Component::TRANSPOSE,
+    Component::BLOCK_LINEARITY,
+    Component::BIAS,
+    Component::WINDOW,
+    Component::LOG,
+
+    Component::BLOCK_ARRAY,
+  };
+  static const std::string TAGS[] = {
+    "<biasedlinearity>",
+    "<sharedlinearity>",
+
+    "<sigmoid>",
+    "<softmax>",
+    "<blocksoftmax>",
+
+    "<expand>",
+    "<copy>",
+    "<transpose>",
+    "<blocklinearity>",
+    "<bias>",
+    "<window>",
+    "<log>",
+
+    "<blockarray>",
+  };
+  static const int MAX = sizeof TYPES / sizeof TYPES[0];
+
+  int i;
+  for(i=0; i<MAX; ++i) {
+    if(TYPES[i] == rComp.GetType()) break;
+  }
+  if(i == MAX) Error("Unknown ComponentType");
+
+  //dump the component tag
+  rOut << TAGS[i] << " "
+       << rComp.GetNOutputs() << " "
+       << rComp.GetNInputs() << std::endl;
+
+  //dump the parameters (if any)
+  rComp.WriteToStream(rOut);
+}
+
+
+
+
+} //namespace
+
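
Note on the long-input path of Network::Feedforward above: inputs of 5000 rows or more are processed in 1024-row blocks, each padded with start_frm_ext rows of left context and end_frm_ext rows of right context so that components needing temporal context see valid neighbour frames; only the central rows of each block's output are kept. The standalone sketch below is not part of TNet (the row count and context sizes are made-up values, and it assumes the network preserves the row count, as the original assert implies); it replays only the index arithmetic of that loop and checks the same done == rows invariant:

// Sketch: replay Network::Feedforward's long-input blocking with plain
// indices, printing which input range produces which output range.
#include <cassert>
#include <cstdio>

int main() {
  const int rows      = 5000;  // pretend input length (>= the 5000 threshold)
  const int start_ext = 5;     // start_frm_ext: left context per block
  const int end_ext   = 5;     // end_frm_ext: right context per block
  const int block     = 1024;

  int done = 0;

  // first part: input [0, block+end_ext), output rows [0, block)
  std::printf("in[%4d,%4d) -> out[%4d,%4d)\n", 0, block + end_ext, 0, block);
  done += block;

  // middle parts: each input window overlaps its neighbours by the context
  // sizes, but only the central 'block' rows of the result are kept
  while (done + 2 * block < rows) {
    std::printf("in[%4d,%4d) -> out[%4d,%4d)\n",
                done - start_ext, done + block + end_ext, done, done + block);
    done += block;
  }

  // last part: everything that remains, plus left context
  std::printf("in[%4d,%4d) -> out[%4d,%4d)\n",
              done - start_ext, rows, done, rows);
  // mirrors 'done += tmp_out.Rows()-start_frm_ext' when the network
  // preserves the number of rows
  done += (rows - done + start_ext) - start_ext;

  assert(done == rows);  // same invariant asserted in Feedforward
  return 0;
}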
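
For reference, the stream format written by ComponentDumper and parsed by ComponentFactory is a sequence of headers of the form <tag> nOutputs nInputs (outputs first), each followed by whatever the component's WriteToStream/ReadFromStream emit, with <endblock> or end-of-file terminating the list. A hypothetical three-layer MLP (dimensions illustrative, parameter blocks elided) would look like:

<biasedlinearity> 512 39
  ...weight matrix and bias vector written by BiasedLinearity::WriteToStream...
<sigmoid> 512 512
<biasedlinearity> 135 512
  ...weight matrix and bias vector...
<softmax> 135 135

Note that the two TAGS arrays and the TYPES array must stay index-aligned: ComponentFactory dispatches on the tag's position in TAGS, and ComponentDumper maps GetType() back through the same positions.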