From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- src/CuTNetLib/.svn/text-base/cuNetwork.cc.svn-base | 380 +++++++++++++++++++++ 1 file changed, 380 insertions(+) create mode 100644 src/CuTNetLib/.svn/text-base/cuNetwork.cc.svn-base (limited to 'src/CuTNetLib/.svn/text-base/cuNetwork.cc.svn-base') diff --git a/src/CuTNetLib/.svn/text-base/cuNetwork.cc.svn-base b/src/CuTNetLib/.svn/text-base/cuNetwork.cc.svn-base new file mode 100644 index 0000000..e245699 --- /dev/null +++ b/src/CuTNetLib/.svn/text-base/cuNetwork.cc.svn-base @@ -0,0 +1,380 @@ + +#include +//#include +#include +#include +#include + +#include "cuNetwork.h" + +#include "cuDiscreteLinearity.h" +#include "cuSharedLinearity.h" +#include "cuSparseLinearity.h" +#include "cuRbm.h" +#include "cuRbmSparse.h" +#include "cuRecurrent.h" +#include "cuBlockArray.h" + +namespace TNet { + + + + + void + CuNetwork:: + ReadNetwork(const char* pSrc) + { + std::ifstream in(pSrc); + if(!in.good()) { + Error(std::string("Error, cannot read model: ")+pSrc); + } + ReadNetwork(in); + in.close(); + } + + + + void + CuNetwork:: + WriteNetwork(const char* pDst) + { + std::ofstream out(pDst); + if(!out.good()) { + Error(std::string("Error, cannot write model: ")+pDst); + } + WriteNetwork(out); + out.close(); + } + + + + void + CuNetwork:: + ReadNetwork(std::istream& rIn) + { + //get the network elements from a factory + CuComponent *pComp; + while(NULL != (pComp = ComponentFactory(rIn))) { + mNetComponents.push_back(pComp); + } + } + + + + void + CuNetwork:: + WriteNetwork(std::ostream& rOut) + { + //dump all the componetns + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + ComponentDumper(rOut, **it); + } + } + + + void + CuNetwork:: + SetLearnRate(BaseFloat learnRate, const char* pLearnRateFactors) + { + //parse the learn rate factors: "0.1:0.5:0.6:1.0" to std::list + std::list lr_factors; + if(NULL != pLearnRateFactors) { + //replace ':' by ' ' + std::string str(pLearnRateFactors); + size_t pos = 0; + while((pos = str.find(':',pos)) != std::string::npos) str[pos] = ' '; + while((pos = str.find(',',pos)) != std::string::npos) str[pos] = ' '; + + //parse to std::list + std::istringstream is(str); + is >> std::skipws; + BaseFloat f; + while(!is.eof()) { + if(!(is >> f).fail()) { lr_factors.push_back(f); } + else break; + } + } + + //initialize rate factors iterator + BaseFloat scale = 1.0f; + + //store global learning rate + mGlobLearnRate = learnRate; + mpLearnRateFactors = pLearnRateFactors; + + //give scaled learning rate to components + LayeredType::iterator it; + bool stopper_given = false; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + if((*it)->IsUpdatable()) { + //get next scale factor + if(NULL != pLearnRateFactors) { + if(!(lr_factors.size() > 0)) { + Error("Too few learninig rate scale factors"); + } + scale = lr_factors.front(); + lr_factors.pop_front(); + } + //set scaled learning rate to the component + dynamic_cast(*it)->LearnRate(learnRate*scale); + //set the stopper component for backpropagation + if(!stopper_given && (learnRate*scale > 0.0)) { + mpPropagErrorStopper = *it; stopper_given = true; + } + } + } + if(lr_factors.size() > 0) { + Error("Too much learninig rate scale factors"); + } + } + + + BaseFloat + CuNetwork:: + GetLearnRate() + { + return mGlobLearnRate; + } + + + void + CuNetwork:: + PrintLearnRate() + { + assert(mNetComponents.size() > 0); + std::cout << "Learning rate: global " << mGlobLearnRate; + std::cout << " components' "; + for(size_t i=0; iIsUpdatable()) { + std::cout << " " << dynamic_cast(mNetComponents[i])->LearnRate(); + } + } + std::cout << "\n" << std::flush; + } + + + + void + CuNetwork:: + SetMomentum(BaseFloat momentum) + { + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + if((*it)->IsUpdatable()) { + dynamic_cast(*it)->Momentum(momentum); + } + } + } + + void + CuNetwork:: + SetWeightcost(BaseFloat weightcost) + { + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + if((*it)->IsUpdatable()) { + dynamic_cast(*it)->Weightcost(weightcost); + } + } + } + + void + CuNetwork:: + SetL1(BaseFloat l1) + { + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + if((*it)->GetType() == CuComponent::SPARSE_LINEARITY) { + dynamic_cast(*it)->L1(l1); + } + } + } + + void + CuNetwork:: + SetGradDivFrm(bool div) + { + LayeredType::iterator it; + for(it=mNetComponents.begin(); it!=mNetComponents.end(); ++it) { + if((*it)->IsUpdatable()) { + dynamic_cast(*it)->GradDivFrm(div); + } + } + } + + + CuComponent* + CuNetwork:: + ComponentFactory(std::istream& rIn) + { + rIn >> std::ws; + if(rIn.eof()) return NULL; + + CuComponent* pRet=NULL; + CuComponent* pPred=NULL; + + std::string componentTag; + size_t nInputs, nOutputs; + + rIn >> std::ws; + rIn >> componentTag; + if(componentTag == "") return NULL; //nothing left in the file + + //make it lowercase + std::transform(componentTag.begin(), componentTag.end(), + componentTag.begin(), tolower); + + if(componentTag[0] != '<' || componentTag[componentTag.size()-1] != '>') { + Error(std::string("Invalid component tag:")+componentTag); + } + + //the 'endblock' tag terminates the network + if(componentTag == "") return NULL; + + rIn >> std::ws; + rIn >> nOutputs; + rIn >> std::ws; + rIn >> nInputs; + assert(nInputs > 0 && nOutputs > 0); + + //make coupling with predecessor + if(mNetComponents.size() != 0) { + pPred = mNetComponents.back(); + } + + //array with list of component tags + static const std::string TAGS[] = { + "", + "", + "", + "", + "", + "", + "", + + "", + "", + + "", + "", + "", + "", + "", + "", + "", + + "", + }; + + static const int n_tags = sizeof(TAGS) / sizeof(TAGS[0]); + int i; + for(i=0; iReadFromStream(rIn); + + //return + return pRet; + } + + + void + CuNetwork:: + ComponentDumper(std::ostream& rOut, CuComponent& rComp) + { + //use tags of all the components; or the identification codes + //array with list of component tags + static const CuComponent::ComponentType TYPES[] = { + CuComponent::BIASED_LINEARITY, + CuComponent::DISCRETE_LINEARITY, + CuComponent::SHARED_LINEARITY, + CuComponent::SPARSE_LINEARITY, + CuComponent::RBM, + CuComponent::RBM_SPARSE, + CuComponent::RECURRENT, + + CuComponent::SIGMOID, + CuComponent::SOFTMAX, + + CuComponent::EXPAND, + CuComponent::COPY, + CuComponent::TRANSPOSE, + CuComponent::BLOCK_LINEARITY, + CuComponent::BIAS, + CuComponent::WINDOW, + CuComponent::LOG, + + CuComponent::BLOCK_ARRAY, + }; + static const std::string TAGS[] = { + "", + "", + "", + "", + "", + "", + "", + + "", + "", + + "", + "", + "", + "", + "", + "", + "", + + "", + }; + static const int MAX = sizeof TYPES / sizeof TYPES[0]; + + int i; + for(i=0; i