From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001
From: Joe Zhao
Date: Mon, 14 Apr 2014 08:14:45 +0800
Subject: First commit

---
 src/TNetLib/Platform.h | 402 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 402 insertions(+)
 create mode 100644 src/TNetLib/Platform.h

diff --git a/src/TNetLib/Platform.h b/src/TNetLib/Platform.h
new file mode 100644
index 0000000..628b9cd
--- /dev/null
+++ b/src/TNetLib/Platform.h
@@ -0,0 +1,402 @@
+#ifndef _TNET_PLATFORM_H
+#define _TNET_PLATFORM_H
+
+/**
+ * \file Platform.h
+ * \brief DNN training class, multicore version
+ */
+
+#include "Thread.h"
+#include "Matrix.h"
+
+#include "Features.h"
+#include "Labels.h"
+
+#include "Cache.h"
+#include "Nnet.h"
+#include "ObjFun.h"
+
+#include "Mutex.h"
+#include "Semaphore.h"
+#include "Barrier.h"
+#include "Thread.h"
+
+#include <vector>
+#include <list>
+#include <iterator>
+
+namespace TNet {
+
+class PlatformThread;
+
+class Platform {
+
+/*
+ * Variables to be initialized directly from the main function
+ */
+ public:
+  FeatureRepository feature_;  ///< Features specified in the input arguments and script file
+  LabelRepository label_;      ///< Labels specified in the label map file
+
+  Network nnet_transf_;        ///< NNet feature transform
+  Network nnet_;               ///< NNet to train
+  ObjectiveFunction* obj_fun_; ///< Objective function for the training
+
+  int bunchsize_;
+  int cachesize_;
+  bool randomize_;
+
+  int start_frm_ext_;
+  int end_frm_ext_;
+
+  int trace_;
+  bool crossval_;
+
+  long int seed_;
+
+/*
+ * Variables to be used internally during the multi-threaded training
+ */
+ private:
+  Semaphore semaphore_read_;
+
+  std::vector<std::list<Matrix<BaseFloat>*> > feature_buf_;
+  std::vector<std::list<Matrix<BaseFloat>*> > label_buf_;
+  std::vector<Mutex> mutex_buf_;
+
+  std::vector<Network*> nnet_transf2_;
+
+  std::vector<Cache> cache_;
+
+  std::vector<Network*> nnet2_;
+  std::vector<ObjectiveFunction*> obj_fun2_;
+  std::vector<bool> sync_mask_;
+
+  Barrier barrier_;
+  bool end_reading_;
+  std::vector<Timer> tim_;
+  std::vector<double> tim_accu_;
+
+  int num_thr_;
+  Semaphore semaphore_endtrain_;
+  Semaphore semaphore_endtrain2_;
+
+ public:
+  Mutex cout_mutex_;
+
+/*
+ * Methods
+ */
+ public:
+  Platform()
+   : bunchsize_(0), cachesize_(0), randomize_(false),
+     start_frm_ext_(0), end_frm_ext_(0), trace_(0),
+     crossval_(false), seed_(0),
+     end_reading_(false), num_thr_(0)
+  { }
+
+  ~Platform()
+  {
+    for(size_t i=0; i<nnet_transf2_.size(); i++) delete nnet_transf2_[i];
+    for(size_t i=0; i<nnet2_.size(); i++) delete nnet2_[i];
+    for(size_t i=0; i<obj_fun2_.size(); i++) delete obj_fun2_[i];
+  }
+
+  /// Run the multi-threaded training
+  void RunTrain(int num_thr);
+
+ private:
+  /// Data-reading thread loop
+  void ReadData();
+  /// Training thread loop
+  void Thread(int thr_id);
+
+  friend class PlatformThread;
+};
+
+
+
+/**
+ * Inherit Thread for the training threads
+ */
+class PlatformThread : public Thread {
+ public:
+  PlatformThread(Platform* pf)
+   : platform_(*pf)
+  { }
+
+ private:
+  void Execute(void* arg) {
+    intptr_t thr_id = reinterpret_cast<intptr_t>(arg);
+    platform_.Thread(thr_id);
+  }
+
+ private:
+  Platform& platform_;
+};
+
+
+
+
+
+void Platform::RunTrain(int num_thr) {
+  num_thr_ = num_thr;
+
+  /*
+   * Initialize parallel training
+   */
+  feature_buf_.resize(num_thr);
+  label_buf_.resize(num_thr);
+  mutex_buf_.resize(num_thr);
+  cache_.resize(num_thr);
+  sync_mask_.resize(num_thr);
+  barrier_.SetThreshold(num_thr);
+
+  tim_.resize(num_thr);
+  tim_accu_.resize(num_thr,0.0);
+
+  int bunchsize = bunchsize_/num_thr;
+  int cachesize = (cachesize_/num_thr/bunchsize)*bunchsize;
+  std::cout << "Bunchsize:" << bunchsize << "*" << num_thr << "=" << bunchsize*num_thr
+            << " Cachesize:" << cachesize << "*" << num_thr << "=" << cachesize*num_thr << "\n";
+
+  for(int i=0; i<num_thr; i++) {
+    //clone the feature transform
+    nnet_transf2_.push_back(nnet_transf_.Clone());
+    //initialize the per-thread cache
+    cache_[i].Init(cachesize,bunchsize,seed_);
+    //clone the network and the objective function
+    nnet2_.push_back(nnet_.Clone());
+    obj_fun2_.push_back(obj_fun_->Clone());
+    //enable threads to sync weights
+    sync_mask_[i] = true;
+  }
+
+  /*
+   * Run training threads
+   */
+  std::vector<PlatformThread*> threads;
+  for(intptr_t i=0; i<num_thr; i++) {
+    PlatformThread* t = new PlatformThread(this);
+    t->Start(reinterpret_cast<void*>(i));
+    threads.push_back(t);
+  }
+
+  /*
+   * Read the training data
+   */
+  ReadData();
+
+  /*
+   * Wait for training to finish
+   */
+  semaphore_endtrain2_.Wait();
+
+}
+
+
+
+void Platform::ReadData() try {
+  cout_mutex_.Lock();
+  std::cout << "queuesize " << feature_.QueueSize() << "\n";
+  cout_mutex_.Unlock();
+
+  int thr = 0;
+  for(feature_.Rewind(); !feature_.EndOfList(); feature_.MoveNext()) {
+    Matrix<BaseFloat>* fea = new Matrix<BaseFloat>;
+    Matrix<BaseFloat>* lab = new Matrix<BaseFloat>;
+
+    feature_.ReadFullMatrix(*fea);
+    label_.GenDesiredMatrix(*lab,
+      fea->Rows()-start_frm_ext_-end_frm_ext_,
+      feature_.CurrentHeader().mSamplePeriod,
+      feature_.Current().Logical().c_str());
+
+    fea->CheckData(feature_.Current().Logical());
+
+    mutex_buf_[thr].Lock();
+    feature_buf_[thr].push_back(fea);
+    label_buf_[thr].push_back(lab);
+    mutex_buf_[thr].Unlock();
+
+    //suspend reading while the shortest buffer has more than 20 matrices
+    if(thr == 0) {
+      int minsize = 1e6;
+      for(size_t i=0; i<feature_buf_.size(); i++) {
+        if((int)feature_buf_[i].size() < minsize) minsize = feature_buf_[i].size();
+      }
+      if(minsize > 20) semaphore_read_.Wait();
+    }
+
+    thr = (thr+1) % num_thr_;
+  }
+
+  std::cout << "[Reading finished]\n" << std::flush;
+  end_reading_ = true;
+
+} catch (std::exception& rExc) {
+  std::cerr << "Exception thrown" << std::endl;
+  std::cerr << rExc.what() << std::endl;
+  exit(1);
+}
+
+
+
+void Platform::Thread(int thr_id) try {
+
+  const int thr = thr_id; //make id const for safety!
+
+  while(1) {
+    //fill the cache
+    while(!cache_[thr].Full() && !(end_reading_ && (feature_buf_[thr].size() == 0))) {
+
+      if(feature_buf_[thr].size() <= 5) {
+        semaphore_read_.Post();//wake the reader
+      }
+      if(feature_buf_[thr].size() == 0) {
+        cout_mutex_.Lock();
+        std::cout << "Thread" << thr << ", waiting for data\n";
+        cout_mutex_.Unlock();
+        sleep(1);
+      } else {
+        //get the matrices
+        mutex_buf_[thr].Lock();
+        Matrix<BaseFloat>* fea = feature_buf_[thr].front();
+        Matrix<BaseFloat>* lab = label_buf_[thr].front();
+        feature_buf_[thr].pop_front();
+        label_buf_[thr].pop_front();
+        mutex_buf_[thr].Unlock();
+
+        //transform the features
+        Matrix<BaseFloat> fea_transf;
+        nnet_transf2_[thr]->Propagate(*fea,fea_transf);
+
+        //trim the frame extensions
+        SubMatrix<BaseFloat> fea_trim(
+          fea_transf,
+          start_frm_ext_,
+          fea_transf.Rows()-start_frm_ext_-end_frm_ext_,
+          0,
+          fea_transf.Cols()
+        );
+
+        //add to cache
+        cache_[thr].AddData(fea_trim,*lab);
+
+        delete fea; delete lab;
+      }
+    }
+
+    //no more data, end training...
+    if(cache_[thr].Empty()) break;
+
+    if(randomize_) { cache_[thr].Randomize(); }
+
+    //std::cout << "Thread" << thr << ", Cache#" << nr_cache++ << "\n";
+
+    //train from cache
+    Matrix<BaseFloat> fea2,lab2,out,err;
+    while(!cache_[thr].Empty()) {
+      cache_[thr].GetBunch(fea2,lab2);
+      nnet2_[thr]->Propagate(fea2,out);
+      obj_fun2_[thr]->Evaluate(out,lab2,&err);
+
+      if(!crossval_) {
+        nnet2_[thr]->Backpropagate(err);
+
+        tim_[thr].Start();
+        barrier_.Wait();//*********/
+        tim_[thr].End(); tim_accu_[thr] += tim_[thr].Val();
+
+        //sum the gradient and bunchsize
+        for(int i=0; i<num_thr_; i++) {
+          if(sync_mask_[i]) {
+            nnet_.AccuGradient(*nnet2_[i],thr,num_thr_);
+          }
+        }
+
+        //update the weights
+        nnet_.Update(thr,num_thr_);
+
+        tim_[thr].Start();
+        barrier_.Wait();//*********/
+        tim_[thr].End(); tim_accu_[thr] += tim_[thr].Val();
+      }
+    }
+  }
+
+  //deactivate this thread from the gradient summing
+  sync_mask_[thr] = false;
+  //increase the number of finished threads
+  semaphore_endtrain_.Post();
+
+  //the barriers and the sliced update must be run by all the threads,
+  //so a finished thread keeps joining until everybody is done
+  if(!crossval_) {
+    while(1) {
+      barrier_.Wait();//*********/
+      if(semaphore_endtrain_.GetValue() == num_thr_) break;
+
+      for(int i=0; i<num_thr_; i++) {
+        if(sync_mask_[i]) {
+          nnet_.AccuGradient(*nnet2_[i],thr,num_thr_);
+        }
+      }
+      nnet_.Update(thr,num_thr_);
+      barrier_.Wait();//*********/
+    }
+  }
+
+  //wait for all the threads, merge the objective function statistics
+  if(thr == 0) {
+    for(int i=0; i<num_thr_; i++) {
+      semaphore_endtrain_.Wait();
+    }
+    for(int i=0; i<num_thr_; i++) {
+      obj_fun_->MergeStats(*obj_fun2_[i]);
+    }
+
+    cout_mutex_.Lock();
+    std::cout << "Barrier waiting times per thread\n";
+    std::copy(tim_accu_.begin(),tim_accu_.end(),std::ostream_iterator<double>(std::cout," "));
+    std::cout << "\n";
+    cout_mutex_.Unlock();
+  }
+
+  cout_mutex_.Lock();
+  std::cout << "[Thread" << thr << " finished]\n";
+  cout_mutex_.Unlock();
+
+  if(thr == 0) {
+    semaphore_endtrain2_.Post();
+  }
+} catch (std::exception& rExc) {
+  std::cerr << "Exception thrown" << std::endl;
+  std::cerr << rExc.what() << std::endl;
+  exit(1);
+}
+
+
+
+}//namespace TNet
+
+#endif
-- 
cgit v1.2.3-70-g09d2
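
For reference, a minimal driver sketch showing how this class is meant to be used; this is hypothetical and not part of the patch: the parameter values and the steps marked toolkit-specific are illustrative, only the public member names and RunTrain() come from the header above.

  // usage_sketch.cc -- hypothetical driver for Platform (not part of the patch)
  #include "Platform.h"

  int main() {
    TNet::Platform platform;

    // ...load platform.nnet_, platform.nnet_transf_, platform.obj_fun_ and
    // fill platform.feature_ / platform.label_ here (toolkit-specific)...

    platform.bunchsize_ = 256;    // frames per weight update, split across threads
    platform.cachesize_ = 16384;  // frames buffered in each per-thread cache
    platform.randomize_ = true;   // shuffle each cache before training
    platform.crossval_  = false;  // true = forward pass only, no weight updates
    platform.seed_      = 777;    // cache randomization seed
    platform.start_frm_ext_ = 5;  // context frames trimmed after the transform
    platform.end_frm_ext_   = 5;

    platform.RunTrain(4);         // spawn 4 training threads; the caller reads data
    return 0;
  }

RunTrain(4) divides bunchsize_ and cachesize_ by the thread count, clones the network and objective function once per thread, runs ReadData() on the calling thread to feed the per-thread buffers round-robin, and blocks on semaphore_endtrain2_ until thread 0 reports that training has finished.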