From 20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869 Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Wed, 13 May 2015 13:35:03 +0800 Subject: added split --- CMakeLists.txt | 9 ++- main.cpp | 135 ----------------------------------- split.cpp | 76 ++++++++++++++++++++ tools/fileDataProvider.cpp | 173 +++++++++++++++++++++++++++++++++++++++++++++ tools/fileDataProvider.h | 91 ++---------------------- train.cpp | 147 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 408 insertions(+), 223 deletions(-) delete mode 100644 main.cpp create mode 100644 split.cpp create mode 100644 tools/fileDataProvider.cpp create mode 100644 train.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6920572..180456c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,9 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) -set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp ./model/rankaccu.cpp) -add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/easylogging++.h tools/matrixIO.h tools/fileDataProvider.h) -TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) \ No newline at end of file +set(SOURCE_FILES model/ranksvm.cpp model/ranksvmtn.cpp model/rankaccu.cpp tools/fileDataProvider.cpp) +add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h) +add_executable(split split.cpp ${SOURCE_FILES}) +add_dependencies(ranksvm split) +TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) +TARGET_LINK_LIBRARIES( split ${Boost_LIBRARIES}) \ No newline at end of file diff --git a/main.cpp b/main.cpp deleted file mode 100644 index e8666f8..0000000 --- a/main.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include -#include -#include -#include -#include "tools/easylogging++.h" -#include "model/ranksvmtn.h" -#include "tools/fileDataProvider.h" -#include "model/rankaccu.h" - -INITIALIZE_EASYLOGGINGPP - -using namespace Eigen; -using namespace std; -namespace po = boost::program_options; - -po::variables_map vm; - -typedef int (*mainFunc)(DataProvider &dp); - -int train(DataProvider &dp) { - RSVM *rsvm; - rsvm = RSVM::loadModel(vm["model"].as()); - - dp.open(); - DataList D; - - LOG(INFO)<<"Training started"; - dp.getAllDataSet(D); - LOG(INFO)<<"Read "<train(D); - vector L; - rsvm->predict(D,L); - - LOG(INFO)<<"Training finished,saving model"; - - dp.close(); - rsvm->saveModel(vm["output"].as().c_str()); - delete rsvm; - return 0; -} - -int predict(DataProvider &dp) { - RSVM *rsvm; - rsvm = RSVM::loadModel(vm["model"].as().c_str()); - - dp.open(); - DataList D; - vector L; - CMC cmc; - LOG(INFO)<<"Prediction started"; - - ofstream fout; - - ostream* ot; - - if (vm.count("output")) { - fout.open(vm["output"].as().c_str()); - ot=&fout; - } - else - ot=&cout; - - while (!dp.EOFile()) - { - dp.getDataSet(D); - LOG(INFO)<<"Read "<predict(D,L); - - if (vm.count("validate")) - { - rank_accu(D,L); - if (vm.count("cmc")) - rank_CMC(D,L,cmc); - } - - if (vm.count("output") || !vm.count("validate")) - for (int i=0; i cur = cmc.getAcc(); - for (int i = 0;i(), "set input model file") - ("output,o", po::value(), "set output model/prediction file") - ("feature,i", po::value(), "set input feature file"); - - // Parsing program options - po::store(po::parse_command_line(argc, argv, desc), vm); - po::notify(vm); - - // Print help if necessary - if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { - cout << desc; - return 0; - } - mainFunc mainf; - if (vm.count("train")) { - mainf = &train; - } - else if (vm.count("validate")||vm.count("predict")) { - mainf = &predict; - } - else return 0; - DataProvider* dp; - if (vm["feature"].as().find(".rid") == string::npos) - dp = new FileDP(vm["feature"].as()); - else - dp = new RidFileDP(vm["feature"].as()); - mainf(*dp); - delete dp; - return 0; -} diff --git a/split.cpp b/split.cpp new file mode 100644 index 0000000..be80545 --- /dev/null +++ b/split.cpp @@ -0,0 +1,76 @@ +// +// Created by joe on 5/13/15. +// + +#include +#include +#include "tools/dataProvider.h" +#include "tools/fileDataProvider.h" +#include +#include + +INITIALIZE_EASYLOGGINGPP + +using namespace std; +namespace po = boost::program_options; + +po::variables_map vm; + +int outputRid(vector a,int fsize,string fname) +{ + ofstream fout(fname.c_str()); + fout<qid; + for (int j=0;jfeature(j); + fout<(), "take number") + ("take,a", po::value(), "set output rid file 1(taken)") + ("left,b", po::value(), "set output rid file 2(left)") + ("input,i", po::value(), "set input Rid file"); + + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + // Print help if necessary + if (vm.count("help")) { + cout << desc; + return 0; + } + + if (vm.count("query")){ + RidFileDP dp(vm["input"].as().c_str()); + dp.open(); + cout<().c_str()); + vector a; + vector b; + dp.open(); + dp.take(vm["count"].as(),a,b); + outputRid(a,dp.getfSize(),vm["take"].as()); + outputRid(b,dp.getfSize(),vm["left"].as()); + dp.close(); + return 0; +} \ No newline at end of file diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp new file mode 100644 index 0000000..e9b7f3d --- /dev/null +++ b/tools/fileDataProvider.cpp @@ -0,0 +1,173 @@ +// +// Created by joe on 5/13/15. +// + +#include "fileDataProvider.h" +#include +#include + +using namespace std; + +mt19937 gen; + +int FileDP::getDataSet(DataList &out){ + DataEntry* e; + out.clear(); + int fsize; + fin>>fsize; + LOG(INFO)<<"Feature size:"<>e->rank; + if (e->rank == 0) + { + delete e; + break; + } + fin>>e->qid; + e->feature.resize(fsize); + for (int i=0;i>e->feature(i); + } + out.addEntry(e); + } + eof=true; + return 0; +} + +void RidFileDP::readEntries() { + DataEntry *e; + int fsize; + d.clear(); + fin >> fsize; + LOG(INFO) << "Feature size:" << fsize; + d.setfSize(fsize); + while (!fin.eof()) { + e = new DataEntry; + fin >> e->qid; + if (e->qid == "0") { + delete e; + break; + } + e->feature.resize(fsize); + e->rank=-1; + for (int i = 0; i < fsize; ++i) { + fin >> e->feature(i); + } + d.addEntry(e); + } + pos = 0; + qid = 1; + read = true; +} + +int RidFileDP::getDataSet(DataList &out){ + DataEntry *e; + int fsize; + if (!read) + readEntries(); + out.clear(); + fsize = d.getfSize(); + out.setfSize(fsize); + std::vector & dat = d.getData(); + for (int i=0;iqid == dat[pos]->qid) + { + e = new DataEntry; + e->rank=1; + dat[i]->rank=qid; + } + else + { + e = new DataEntry; + e->rank=-1; + } + e->feature.resize(d.getfSize()); + e->qid=dat[pos]->qid; + for (int j = 0; j < fsize; ++j) { + e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j)); + } + out.addEntry(e); + } + dat[pos]->qid=std::to_string(qid); + ++qid; + dat[pos]->rank=qid; + while (posrank!=-1) + ++pos; + if (pos==d.getSize()) + eof = true; + return 0; +} + +int RidFileDP::getpSize() { + std::vector p; + if (!read) + readEntries(); + std::vector &dat = d.getData(); + for (int i=0;iqid ) + { + ext=true; + break; + } + if (!ext) + p.push_back(dat[i]->qid); + } + return p.size(); +}; + +void scrambler(vector &dat) +{ + DataEntry* e; + int sz=(int)dat.size(); + for (int i=0;i &a,vector &b) +{ + gen.seed(time(NULL)); + DataEntry *e; + if (!read) + readEntries(); + vector tmp; + tmp.reserve(d.getSize()); + a.clear(); + b.clear(); + std::vector &dat = d.getData(); + scrambler(tmp); + for (int i=0;iqid; + a.push_back(tmp[pos]); + tmp[pos]=NULL; + for (int j = pos+1; j< tmp.size();++j) + if (tmp[j]!=NULL &&tmp[j]->qid==qid) + { + a.push_back(tmp[j]); + tmp[j]=NULL; + } + } + for (int i=0;i>fsize; - LOG(INFO)<<"Feature size:"<>e->rank; - if (e->rank == 0) - { - delete e; - break; - } - fin>>e->qid; - e->feature.resize(fsize); - for (int i=0;i>e->feature(i); - } - out.addEntry(e); - } - eof=true; - return 0; - } + virtual int getDataSet(DataList &out); virtual int open(){fin.open(fname); eof=false;return 0;}; virtual int close(){fin.close();return 0;}; }; @@ -58,68 +34,13 @@ private: int qid; public: RidFileDP(std::string fn=""):fname(fn){read=false;}; - virtual int getDataSet(DataList &out){ - DataEntry *e; - int fsize; - if (!read) { - d.clear(); - fin >> fsize; - LOG(INFO) << "Feature size:" << fsize; - d.setfSize(fsize); - while (!fin.eof()) { - e = new DataEntry; - fin >> e->qid; - if (e->qid == "0") { - delete e; - break; - } - e->feature.resize(fsize); - e->rank=-1; - for (int i = 0; i < fsize; ++i) { - fin >> e->feature(i); - } - d.addEntry(e); - } - pos = 0; - qid = 1; - read = true; - } - out.clear(); - fsize = d.getfSize(); - out.setfSize(fsize); - std::vector & dat = d.getData(); - for (int i=0;iqid == dat[pos]->qid) - { - e = new DataEntry; - e->rank=1; - dat[i]->rank=qid; - } - else - { - e = new DataEntry; - e->rank=-1; - } - e->feature.resize(d.getfSize()); - e->qid=dat[pos]->qid; - for (int j = 0; j < fsize; ++j) { - e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j)); - } - out.addEntry(e); - } - dat[pos]->qid=std::to_string(qid); - ++qid; - dat[pos]->rank=qid; - while (posrank!=-1) - ++pos; - if (pos==d.getSize()) - eof = true; - return 0; - } + void readEntries(); + int getfSize() { if(!read) readEntries(); return d.getfSize();}; + int getpSize(); + virtual int getDataSet(DataList &out); virtual int open(){fin.open(fname); eof=false;return 0;}; virtual int close(){fin.close(); d.clear();return 0;}; + void take(int n,std::vector &a,std::vector &b); }; #endif \ No newline at end of file diff --git a/train.cpp b/train.cpp new file mode 100644 index 0000000..a0c62a9 --- /dev/null +++ b/train.cpp @@ -0,0 +1,147 @@ +#include +#include +#include +#include +#include "tools/easylogging++.h" +#include "model/ranksvmtn.h" +#include "tools/fileDataProvider.h" +#include "model/rankaccu.h" + +INITIALIZE_EASYLOGGINGPP + +using namespace Eigen; +using namespace std; +namespace po = boost::program_options; + +po::variables_map vm; + +typedef int (*mainFunc)(DataProvider &dp); + +int train(DataProvider &dp) { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as()); + + dp.open(); + DataList D; + + LOG(INFO)<<"Training started"; + dp.getAllDataSet(D); + LOG(INFO)<<"Read "<train(D); + vector L; + rsvm->predict(D,L); + + LOG(INFO)<<"Training finished,saving model"; + + dp.close(); + rsvm->saveModel(vm["output"].as().c_str()); + delete rsvm; + return 0; +} + +int predict(DataProvider &dp) { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as().c_str()); + + dp.open(); + DataList D; + vector L; + CMC cmc; + LOG(INFO)<<"Prediction started"; + + ofstream fout; + + ostream* ot; + + if (vm.count("output")) { + fout.open(vm["output"].as().c_str()); + ot=&fout; + } + else + ot=&cout; + + while (!dp.EOFile()) + { + dp.getDataSet(D); + LOG(INFO)<<"Read "<predict(D,L); + + if (vm.count("validate")) + { + rank_accu(D,L); + if (vm.count("cmc")) + rank_CMC(D,L,cmc); + } + + if (vm.count("output") || !vm.count("validate")) + for (int i=0; i cur = cmc.getAcc(); + for (int i = 0;i(), "set input model file") + ("output,o", po::value(), "set output model/prediction file") + ("feature,i", po::value(), "set input feature file"); + + // Parsing program options + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + // Print help if necessary + if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { + cout << desc; + return 0; + } + + if (!vm.count("debug")) + defaultConf.setGlobally(el::ConfigurationType::Enabled, "false"); + // default logger uses default configurations + el::Loggers::reconfigureLogger("default", defaultConf); + + mainFunc mainf; + if (vm.count("train")) { + mainf = &train; + } + else if (vm.count("validate")||vm.count("predict")) { + mainf = &predict; + } + else return 0; + DataProvider* dp; + if (vm["feature"].as().find(".rid") == string::npos) + dp = new FileDP(vm["feature"].as()); + else + dp = new RidFileDP(vm["feature"].as()); + mainf(*dp); + delete dp; + return 0; +} -- cgit v1.2.3-70-g09d2