diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-04-26 22:53:13 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-26 22:53:13 +0800 |
commit | a8a7bf5f9b9a1eb0d41f839afd06cc532356a902 (patch) | |
tree | a70dbb38d76d30a84a8298096a40830ba7b528a7 | |
parent | be756a55086b5a8f62b979b456475c86ec2cfb61 (diff) | |
download | ranksvm-a8a7bf5f9b9a1eb0d41f839afd06cc532356a902.tar.gz ranksvm-a8a7bf5f9b9a1eb0d41f839afd06cc532356a902.tar.bz2 ranksvm-a8a7bf5f9b9a1eb0d41f839afd06cc532356a902.zip |
getAllData
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | main.cpp | 40 | ||||
-rw-r--r-- | tools/dataProvider.h | 10 | ||||
-rw-r--r-- | tools/reidFDataProvider.h | 8 |
4 files changed, 39 insertions, 21 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d8a4e3..18ef86f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,5 +12,5 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp ./model/rankaccu.cpp) -add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/rankaccu.cpp) +add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/rankaccu.cpp tools/reidFDataProvider.h) TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
\ No newline at end of file @@ -24,7 +24,7 @@ int train(DataProvider &dp) { DataList D; LOG(INFO)<<"Training started"; - dp.getDataSet(D); + dp.getAllData(D); LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; rsvm->train(D); std::vector<double> L; @@ -49,30 +49,32 @@ int predict(DataProvider &dp) { std::vector<double> L; LOG(INFO)<<"Prediction started"; - dp.getDataSet(D); - LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; - rsvm->predict(D,L); + std::ofstream fout; + if (vm.count("output")) + fout.open(vm["output"].as<std::string>().c_str()); - if (vm.count("validate")) + while (!dp.EOFile()) { - rank_accu(D,L); + dp.getDataSet(D); + LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; + rsvm->predict(D,L); + + if (vm.count("validate")) + { + rank_accu(D,L); + } + + if (vm.count("output")) + for (int i=0; i<L.size();++i) + fout<<L[i]<<std::endl; + else if (!vm.count("validate")) + for (int i=0; i<L.size();++i) + std::cout<<L[i]<<std::endl; } + LOG(INFO)<<"Finished"; if (vm.count("output")) - { - LOG(INFO)<<"Finished,saving prediction"; - std::ofstream fout(vm["output"].as<std::string>().c_str()); - - for (int i=0; i<L.size();++i) - fout<<L[i]<<std::endl; fout.close(); - } - else if (!vm.count("validate")) - { - LOG(INFO)<<"Finished"; - for (int i=0; i<L.size();++i) - std::cout<<L[i]<<std::endl; - } dp.close(); delete rsvm; return 0; diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 2c3169a..64bfa2d 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -51,7 +51,15 @@ public: DataProvider():eof(false){}; bool EOFile(){return eof;} - + int getAllData(DataList &out){\ + out.clear(); + DataList buf; + while (!EOFile()) + { + getDataSet(buf); + out.getData().insert(out.getData().end(),buf.getData().begin(),buf.getData().end()); + } + } virtual int getDataSet(DataList &out) = 0; virtual int open()=0; virtual int close()=0; diff --git a/tools/reidFDataProvider.h b/tools/reidFDataProvider.h new file mode 100644 index 0000000..9fa833a --- /dev/null +++ b/tools/reidFDataProvider.h @@ -0,0 +1,8 @@ +// +// Created by joe on 4/26/15. +// + +#ifndef RANKSVM_REIDFDATAPROVIDER_H +#define RANKSVM_REIDFDATAPROVIDER_H + +#endif //RANKSVM_REIDFDATAPROVIDER_H |