diff options
Diffstat (limited to 'train.cpp')
-rw-r--r-- | train.cpp | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/train.cpp b/train.cpp new file mode 100644 index 0000000..a0c62a9 --- /dev/null +++ b/train.cpp @@ -0,0 +1,147 @@ +#include <iostream> +#include <Eigen/Dense> +#include <boost/program_options.hpp> +#include <list> +#include "tools/easylogging++.h" +#include "model/ranksvmtn.h" +#include "tools/fileDataProvider.h" +#include "model/rankaccu.h" + +INITIALIZE_EASYLOGGINGPP + +using namespace Eigen; +using namespace std; +namespace po = boost::program_options; + +po::variables_map vm; + +typedef int (*mainFunc)(DataProvider &dp); + +int train(DataProvider &dp) { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as<string>()); + + dp.open(); + DataList D; + + LOG(INFO)<<"Training started"; + dp.getAllDataSet(D); + LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; + rsvm->train(D); + vector<double> L; + rsvm->predict(D,L); + + LOG(INFO)<<"Training finished,saving model"; + + dp.close(); + rsvm->saveModel(vm["output"].as<string>().c_str()); + delete rsvm; + return 0; +} + +int predict(DataProvider &dp) { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as<string>().c_str()); + + dp.open(); + DataList D; + vector<double> L; + CMC cmc; + LOG(INFO)<<"Prediction started"; + + ofstream fout; + + ostream* ot; + + if (vm.count("output")) { + fout.open(vm["output"].as<string>().c_str()); + ot=&fout; + } + else + ot=&cout; + + while (!dp.EOFile()) + { + dp.getDataSet(D); + LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; + rsvm->predict(D,L); + + if (vm.count("validate")) + { + rank_accu(D,L); + if (vm.count("cmc")) + rank_CMC(D,L,cmc); + } + + if (vm.count("output") || !vm.count("validate")) + for (int i=0; i<L.size();++i) + *ot<<L[i]<<endl; + } + + LOG(INFO)<<"Finished"; + if (vm.count("cmc")) + { + LOG(INFO)<< "CMC accounted over " <<cmc.getCount() << " queries"; + *ot << "CMC"<<endl; + vector<double> cur = cmc.getAcc(); + for (int i = 0;i<CMC_MAX;++i) + *ot << cur[i]<<endl; + } + if (vm.count("output")) + fout.close(); + dp.close(); + delete rsvm; + return 0; +} + +int main(int argc, char **argv) { + el::Configurations defaultConf; + defaultConf.setToDefault(); + // Values are always std::string + defaultConf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg"); + + // Defining program options + po::options_description desc("Allowed options"); + desc.add_options() + ("help,h", "produce help message") + ("train,T", "training model") + ("validate,V", "validate model") + ("predict,P", "use model for prediction") + ("cmc,C", "enable cmc auditing") + ("debug,d", "show debug messages") + ("model,m", po::value<string>(), "set input model file") + ("output,o", po::value<string>(), "set output model/prediction file") + ("feature,i", po::value<string>(), "set input feature file"); + + // Parsing program options + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + // Print help if necessary + if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { + cout << desc; + return 0; + } + + if (!vm.count("debug")) + defaultConf.setGlobally(el::ConfigurationType::Enabled, "false"); + // default logger uses default configurations + el::Loggers::reconfigureLogger("default", defaultConf); + + mainFunc mainf; + if (vm.count("train")) { + mainf = &train; + } + else if (vm.count("validate")||vm.count("predict")) { + mainf = &predict; + } + else return 0; + DataProvider* dp; + if (vm["feature"].as<string>().find(".rid") == string::npos) + dp = new FileDP(vm["feature"].as<string>()); + else + dp = new RidFileDP(vm["feature"].as<string>()); + mainf(*dp); + delete dp; + return 0; +} |