diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 | 
| commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
| tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNetLib/.svn/text-base/Nnet.cc.svn-base | |
| download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip | |
First commit
Diffstat (limited to 'src/TNetLib/.svn/text-base/Nnet.cc.svn-base')
| -rw-r--r-- | src/TNetLib/.svn/text-base/Nnet.cc.svn-base | 360 | 
1 files changed, 360 insertions, 0 deletions
| diff --git a/src/TNetLib/.svn/text-base/Nnet.cc.svn-base b/src/TNetLib/.svn/text-base/Nnet.cc.svn-base new file mode 100644 index 0000000..4b364ac --- /dev/null +++ b/src/TNetLib/.svn/text-base/Nnet.cc.svn-base @@ -0,0 +1,360 @@ + +#include <algorithm> +//#include <locale> +#include <cctype> + +#include "Nnet.h" +#include "CRBEDctFeat.h" +#include "BlockArray.h" + +namespace TNet { + + + + +void Network::Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out,  +                          size_t start_frm_ext, size_t end_frm_ext) { +  //empty network: copy input to output  +  if(mNnet.size() == 0) { +    if(out.Rows() != in.Rows() || out.Cols() != in.Cols()) { +      out.Init(in.Rows(),in.Cols()); +    } +    out.Copy(in); +    return; +  } +   +  //short input: propagate in one block   +  if(in.Rows() < 5000) {  +    Propagate(in,out); +  } else {//long input: propagate per parts +    //initialize +    out.Init(in.Rows(),GetNOutputs()); +    Matrix<BaseFloat> tmp_in, tmp_out; +    int done=0, block=1024; +    //propagate first part +    tmp_in.Init(block+end_frm_ext,in.Cols()); +    tmp_in.Copy(in.Range(0,block+end_frm_ext,0,in.Cols())); +    Propagate(tmp_in,tmp_out); +    out.Range(0,block,0,tmp_out.Cols()).Copy( +      tmp_out.Range(0,block,0,tmp_out.Cols()) +    ); +    done += block; +    //propagate middle parts +    while((done+2*block) < in.Rows()) { +      tmp_in.Init(block+start_frm_ext+end_frm_ext,in.Cols()); +      tmp_in.Copy(in.Range(done-start_frm_ext, block+start_frm_ext+end_frm_ext, 0,in.Cols()));      Propagate(tmp_in,tmp_out); +      out.Range(done,block,0,tmp_out.Cols()).Copy( +        tmp_out.Range(start_frm_ext,block,0,tmp_out.Cols()) +      ); +      done += block; +    } +    //propagate last part +    tmp_in.Init(in.Rows()-done+start_frm_ext,in.Cols()); +    tmp_in.Copy(in.Range(done-start_frm_ext,in.Rows()-done+start_frm_ext,0,in.Cols())); +    Propagate(tmp_in,tmp_out); +    out.Range(done,out.Rows()-done,0,out.Cols()).Copy( +      tmp_out.Range(start_frm_ext,tmp_out.Rows()-start_frm_ext,0,tmp_out.Cols())    +    ); + +    done += tmp_out.Rows()-start_frm_ext; +    assert(done == out.Rows()); +  } +} + + +void Network::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>& out) { +  //empty network: copy input to output  +  if(mNnet.size() == 0) { +    if(out.Rows() != in.Rows() || out.Cols() != in.Cols()) { +      out.Init(in.Rows(),in.Cols()); +    } +    out.Copy(in); +    return; +  } +   +  //this will keep pointer to matrix 'in', for backprop +  mNnet.front()->SetInput(in);  + +  //propagate +  LayeredType::iterator it; +  for(it=mNnet.begin(); it!=mNnet.end(); ++it) { +    (*it)->Propagate(); +  } + +  //copy the output matrix +  const Matrix<BaseFloat>& mat = mNnet.back()->GetOutput(); +  if(out.Rows() != mat.Rows() || out.Cols() != mat.Cols()) { +    out.Init(mat.Rows(),mat.Cols()); +  } +  out.Copy(mat); + +} + + +void Network::Backpropagate(const Matrix<BaseFloat>& globerr) { +  //pass matrix to last component +  mNnet.back()->SetErrorInput(globerr); + +  // back-propagation : reversed order, +  LayeredType::reverse_iterator it; +  for(it=mNnet.rbegin(); it!=mNnet.rend(); ++it) { +    //first component does not backpropagate error (no predecessors) +    if(*it != mNnet.front()) { +      (*it)->Backpropagate(); +    } +    //compute gradient if updatable component +    if((*it)->IsUpdatable()) { +      UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it); +      comp.Gradient(); //compute gradient  +    } +  } +} + + +void Network::AccuGradient(const Network& src, int thr, int thrN) { +  LayeredType::iterator it; +  LayeredType::const_iterator it2; + +  for(it=mNnet.begin(), it2=src.mNnet.begin(); it!=mNnet.end(); ++it,++it2) { +    if((*it)->IsUpdatable()) { +      UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it); +      const UpdatableComponent& comp2 = dynamic_cast<const UpdatableComponent&>(**it2); +      comp.AccuGradient(comp2,thr,thrN); +    } +  } +} + + +void Network::Update(int thr, int thrN) { +  LayeredType::iterator it; + +  for(it=mNnet.begin(); it!=mNnet.end(); ++it) { +    if((*it)->IsUpdatable()) { +      UpdatableComponent& comp = dynamic_cast<UpdatableComponent&>(**it); +      comp.Update(thr,thrN); +    } +  } +} + + +Network* Network::Clone() { +  Network* net = new Network; +  LayeredType::iterator it; +  for(it = mNnet.begin(); it != mNnet.end(); ++it) { +    //clone +    net->mNnet.push_back((*it)->Clone()); +    //connect network +    if(net->mNnet.size() > 1) { +      Component* last = *(net->mNnet.end()-1); +      Component* prev = *(net->mNnet.end()-2); +      last->SetInput(prev->GetOutput()); +      prev->SetErrorInput(last->GetErrorOutput()); +    } +  } + +  //copy the learning rate +  //net->SetLearnRate(GetLearnRate()); + +  return net; +} + + +void Network::ReadNetwork(const char* pSrc) { +  std::ifstream in(pSrc); +  if(!in.good()) { +    Error(std::string("Error, cannot read model: ")+pSrc); +  } +  ReadNetwork(in); +  in.close(); +} + +   + +void Network::ReadNetwork(std::istream& rIn) { +  //get the network elements from a factory +  Component *pComp; +  while(NULL != (pComp = ComponentFactory(rIn)))  +    mNnet.push_back(pComp); +} + + +void Network::WriteNetwork(const char* pDst) { +  std::ofstream out(pDst); +  if(!out.good()) { +    Error(std::string("Error, cannot write model: ")+pDst); +  } +  WriteNetwork(out); +  out.close(); +} + + +void Network::WriteNetwork(std::ostream& rOut) { +  //dump all the componetns +  LayeredType::iterator it; +  for(it=mNnet.begin(); it!=mNnet.end(); ++it) { +    ComponentDumper(rOut, **it); +  } +} +  + +Component* +Network:: +ComponentFactory(std::istream& rIn) +{ +  rIn >> std::ws; +  if(rIn.eof()) return NULL; + +  Component* pRet=NULL; +  Component* pPred=NULL; + +  std::string componentTag; +  size_t nInputs, nOutputs; + +  rIn >> std::ws; +  rIn >> componentTag; +  if(componentTag == "") return NULL; //nothing left in the file + +  //make it lowercase +  std::transform(componentTag.begin(), componentTag.end(),  +                 componentTag.begin(), tolower); + +  //the 'endblock' tag terminates the network +  if(componentTag == "<endblock>") return NULL; + +   +  if(componentTag[0] != '<' || componentTag[componentTag.size()-1] != '>') { +    Error(std::string("Invalid component tag:")+componentTag); +  } + +  rIn >> std::ws; +  rIn >> nOutputs; +  rIn >> std::ws; +  rIn >> nInputs; +  assert(nInputs > 0 && nOutputs > 0); + +  //make coupling with predecessor +  if(mNnet.size() == 0) { +    pPred = NULL; +  } else { +    pPred = mNnet.back(); +  } +   +  //array with list of component tags +  static const std::string TAGS[] = { +    "<biasedlinearity>", +    "<sharedlinearity>", +     +    "<sigmoid>", +    "<softmax>", +    "<blocksoftmax>", + +    "<expand>", +    "<copy>", +    "<transpose>", +    "<blocklinearity>", +    "<bias>", +    "<window>", +    "<log>", + +    "<blockarray>", +  }; + +  static const int n_tags = sizeof(TAGS) / sizeof(TAGS[0]); +  int i = 0; +  for(i=0; i<n_tags; i++) { +    if(componentTag == TAGS[i]) break; +  } +   +  //switch according to position in array TAGS +  switch(i) { +    case 0: pRet = new BiasedLinearity(nInputs,nOutputs,pPred); break; +    case 1: pRet = new SharedLinearity(nInputs,nOutputs,pPred); break; + +    case 2: pRet = new Sigmoid(nInputs,nOutputs,pPred); break; +    case 3: pRet = new Softmax(nInputs,nOutputs,pPred); break; +    case 4: pRet = new BlockSoftmax(nInputs,nOutputs,pPred); break; + +    case 5: pRet = new Expand(nInputs,nOutputs,pPred); break; +    case 6: pRet = new Copy(nInputs,nOutputs,pPred); break; +    case 7: pRet = new Transpose(nInputs,nOutputs,pPred); break; +    case 8: pRet = new BlockLinearity(nInputs,nOutputs,pPred); break; +    case 9: pRet = new Bias(nInputs,nOutputs,pPred); break; +    case 10: pRet = new Window(nInputs,nOutputs,pPred); break; +    case 11: pRet = new Log(nInputs,nOutputs,pPred); break; +     +    case 12: pRet = new BlockArray(nInputs,nOutputs,pPred); break; + +    default: Error(std::string("Unknown Component tag:")+componentTag); +  } +  +  //read params if it is updatable component +  pRet->ReadFromStream(rIn); +  //return +  return pRet; +} + + +void +Network:: +ComponentDumper(std::ostream& rOut, Component& rComp) +{ +  //use tags of all the components; or the identification codes +  //array with list of component tags +  static const Component::ComponentType TYPES[] = { +    Component::BIASED_LINEARITY, +    Component::SHARED_LINEARITY, +     +    Component::SIGMOID, +    Component::SOFTMAX, +    Component::BLOCK_SOFTMAX, + +    Component::EXPAND, +    Component::COPY, +    Component::TRANSPOSE, +    Component::BLOCK_LINEARITY, +    Component::BIAS, +    Component::WINDOW, +    Component::LOG, + +    Component::BLOCK_ARRAY, +  }; +  static const std::string TAGS[] = { +    "<biasedlinearity>", +    "<sharedlinearity>", + +    "<sigmoid>", +    "<softmax>", +    "<blocksoftmax>", + +    "<expand>", +    "<copy>", +    "<transpose>", +    "<blocklinearity>", +    "<bias>", +    "<window>", +    "<log>", + +    "<blockarray>", +  }; +  static const int MAX = sizeof TYPES / sizeof TYPES[0]; + +  int i; +  for(i=0; i<MAX; ++i) { +    if(TYPES[i] == rComp.GetType()) break; +  } +  if(i == MAX) Error("Unknown ComponentType"); +   +  //dump the component tag +  rOut << TAGS[i] << " "  +       << rComp.GetNOutputs() << " "  +       << rComp.GetNInputs() << std::endl; + +  //dump the parameters (if any) +  rComp.WriteToStream(rOut); +} + + + +   +} //namespace + | 
