From 20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869 Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Wed, 13 May 2015 13:35:03 +0800 Subject: added split --- train.cpp | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 train.cpp (limited to 'train.cpp') diff --git a/train.cpp b/train.cpp new file mode 100644 index 0000000..a0c62a9 --- /dev/null +++ b/train.cpp @@ -0,0 +1,147 @@ +#include +#include +#include +#include +#include "tools/easylogging++.h" +#include "model/ranksvmtn.h" +#include "tools/fileDataProvider.h" +#include "model/rankaccu.h" + +INITIALIZE_EASYLOGGINGPP + +using namespace Eigen; +using namespace std; +namespace po = boost::program_options; + +po::variables_map vm; + +typedef int (*mainFunc)(DataProvider &dp); + +int train(DataProvider &dp) { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as()); + + dp.open(); + DataList D; + + LOG(INFO)<<"Training started"; + dp.getAllDataSet(D); + LOG(INFO)<<"Read "<train(D); + vector L; + rsvm->predict(D,L); + + LOG(INFO)<<"Training finished,saving model"; + + dp.close(); + rsvm->saveModel(vm["output"].as().c_str()); + delete rsvm; + return 0; +} + +int predict(DataProvider &dp) { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as().c_str()); + + dp.open(); + DataList D; + vector L; + CMC cmc; + LOG(INFO)<<"Prediction started"; + + ofstream fout; + + ostream* ot; + + if (vm.count("output")) { + fout.open(vm["output"].as().c_str()); + ot=&fout; + } + else + ot=&cout; + + while (!dp.EOFile()) + { + dp.getDataSet(D); + LOG(INFO)<<"Read "<predict(D,L); + + if (vm.count("validate")) + { + rank_accu(D,L); + if (vm.count("cmc")) + rank_CMC(D,L,cmc); + } + + if (vm.count("output") || !vm.count("validate")) + for (int i=0; i cur = cmc.getAcc(); + for (int i = 0;i(), "set input model file") + ("output,o", po::value(), "set output model/prediction file") + ("feature,i", po::value(), "set input feature file"); + + // Parsing program options + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + // Print help if necessary + if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { + cout << desc; + return 0; + } + + if (!vm.count("debug")) + defaultConf.setGlobally(el::ConfigurationType::Enabled, "false"); + // default logger uses default configurations + el::Loggers::reconfigureLogger("default", defaultConf); + + mainFunc mainf; + if (vm.count("train")) { + mainf = &train; + } + else if (vm.count("validate")||vm.count("predict")) { + mainf = &predict; + } + else return 0; + DataProvider* dp; + if (vm["feature"].as().find(".rid") == string::npos) + dp = new FileDP(vm["feature"].as()); + else + dp = new RidFileDP(vm["feature"].as()); + mainf(*dp); + delete dp; + return 0; +} -- cgit v1.2.3-70-g09d2