diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-03-08 16:42:53 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-03-08 16:42:53 +0800 |
commit | 22882d7113c13cb1e00c59b54050f16ac1b7cc30 (patch) | |
tree | b94a20e7834ce52352aa031892d295413b65a372 | |
parent | e500bb4cdb32b13cc022b6dc5d221de7ad97a73e (diff) | |
download | ranksvm-22882d7113c13cb1e00c59b54050f16ac1b7cc30.tar.gz ranksvm-22882d7113c13cb1e00c59b54050f16ac1b7cc30.tar.bz2 ranksvm-22882d7113c13cb1e00c59b54050f16ac1b7cc30.zip |
migrating & scaffolding & reformatting
Added more cmdline options
-rw-r--r-- | CMakeLists.txt | 6 | ||||
-rw-r--r-- | main.cpp | 88 | ||||
-rw-r--r-- | model/ranksvm.cpp | 3 | ||||
-rw-r--r-- | model/ranksvm.h | 4 | ||||
-rw-r--r-- | tools/dataProvider.h | 2 |
5 files changed, 60 insertions, 43 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 99c798e..6ccf0c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 2.8.4) -project(eigenTest) +project(ranksvm) # Use Eigen3 Library for Linear Algebra INCLUDE_DIRECTORIES ( "/usr/include/eigen3" ) @@ -12,5 +12,5 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) set(SOURCE_FILES main.cpp ./model/ranksvm.cpp) -add_executable(eigenTest ${SOURCE_FILES}) -TARGET_LINK_LIBRARIES( eigenTest ${Boost_LIBRARIES} )
\ No newline at end of file +add_executable(ranksvm ${SOURCE_FILES}) +TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
\ No newline at end of file @@ -13,42 +13,58 @@ namespace po = boost::program_options; po::variables_map vm; -int train() -{ - RSVM* rsvm; - rsvm=RSVM::loadModel(vm["model"].as<std::string>().c_str()); - FileDP dp(vm["feature"].as<std::string>().c_str()); - rsvm->train(dp); - rsvm->predict(dp); - rsvm->saveModel(vm["output"].as<std::string>().c_str()); - delete rsvm; - return 0; +int train() { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); + FileDP dp(vm["feature"].as<std::string>().c_str()); + rsvm->train(dp); + rsvm->saveModel(vm["output"].as<std::string>().c_str()); + delete rsvm; + return 0; } -int main(int argc,char ** argv) -{ - // Defining program options - po::options_description desc("Allowed options"); - desc.add_options() - ("help,h", "produce help message") - ("model,m", po::value<std::string>()->default_value("test.m"), "set input model file") - ("feature,i", po::value<std::string>()->default_value("test.f"), "set input feature file") - ("output,o", po::value<std::string>()->default_value("test.m.out"), "set output model file") - ; - - // Parsing program options - po::store(po::parse_command_line(argc, argv, desc), vm); - po::notify(vm); - - if (vm.count("help")) { - std::cout << desc << "\n"; - return 1; - } - - MatrixXd m(2,2); - m(0,0) = 3; - m(1,0) = 2.5; - m(0,1) = -1; - m(1,1) = m(1,0) + m(0,1); - std::cout << m << std::endl; +int predict() { + RSVM *rsvm; + rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); + FileDP dp(vm["feature"].as<std::string>().c_str()); + rsvm->predict(dp); + delete rsvm; + return 0; +} + +int main(int argc, char **argv) { + // Defining program options + po::options_description desc("Allowed options"); + desc.add_options() + ("help,h", "produce help message") + ("train,T", "training model") + ("validate,V", "validate model") + ("predict,P", "use model for prediction") + ("model,m", po::value<std::string>(), "set input model file") + ("output,o", po::value<std::string>(), "set output model file") + ("feature,i", po::value<std::string>(), "set input feature file"); + + // Parsing program options + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + // Print help if necessary + if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { + std::cout << desc; + return 0; + } + + if (vm.count("train")) { + LOG(INFO) << "Program option: training"; + train(); + } + else if (vm.count("validate")) { + LOG(INFO) << "Program option: validate"; + predict(); + } + else if (vm.count("predict")) { + LOG(INFO) << "Program option: predict"; + predict(); + } + return 0; } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 58a097a..6294245 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -25,7 +25,6 @@ RSVM* RSVM::loadModel(const string fname){ RSVM* rsvm; - // TODO multiplex type if (type=="TN") rsvm = new RSVMTN(); @@ -39,7 +38,7 @@ RSVM* RSVM::loadModel(const string fname){ int RSVM::setModel(const Eigen::VectorXd &model) { if (model.cols()!=fsize) - LOG(FATAL) << "Feature size mismatch";; + LOG(FATAL) << "Feature size mismatch"; this->model=model; return 0; }
\ No newline at end of file diff --git a/model/ranksvm.h b/model/ranksvm.h index fad790d..e7b7c4a 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,8 +12,8 @@ protected: Eigen::VectorXd model; int fsize; public: - virtual int train(DataProvider &D)=0; - virtual int predict(DataProvider &D)=0; + virtual int train(DataProvider &D)=0; // Dataprovider will have to provide label + virtual int predict(DataProvider &D)=0; // TODO Not sure how to construct this int saveModel(const std::string fname); static RSVM* loadModel(const std::string fname); virtual std::string getName()=0; diff --git a/tools/dataProvider.h b/tools/dataProvider.h index b2384c9..d9440ce 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -4,6 +4,8 @@ #include<Eigen/Dense> #include "../tools/easylogging++.h" +// TODO decide how to construct training data + class DataProvider //Virtual base class for data input { protected: |