diff options
| -rw-r--r-- | CMakeLists.txt | 6 | ||||
| -rw-r--r-- | main.cpp | 88 | ||||
| -rw-r--r-- | model/ranksvm.cpp | 3 | ||||
| -rw-r--r-- | model/ranksvm.h | 4 | ||||
| -rw-r--r-- | tools/dataProvider.h | 2 | 
5 files changed, 60 insertions, 43 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 99c798e..6ccf0c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@  cmake_minimum_required(VERSION 2.8.4) -project(eigenTest) +project(ranksvm)  # Use Eigen3 Library for Linear Algebra  INCLUDE_DIRECTORIES ( "/usr/include/eigen3" ) @@ -12,5 +12,5 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )  INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR})  set(SOURCE_FILES main.cpp ./model/ranksvm.cpp) -add_executable(eigenTest ${SOURCE_FILES}) -TARGET_LINK_LIBRARIES( eigenTest ${Boost_LIBRARIES} )
\ No newline at end of file +add_executable(ranksvm ${SOURCE_FILES}) +TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
\ No newline at end of file @@ -13,42 +13,58 @@ namespace po = boost::program_options;  po::variables_map vm; -int train() -{ -  RSVM* rsvm; -  rsvm=RSVM::loadModel(vm["model"].as<std::string>().c_str()); -  FileDP dp(vm["feature"].as<std::string>().c_str()); -  rsvm->train(dp); -  rsvm->predict(dp); -  rsvm->saveModel(vm["output"].as<std::string>().c_str()); -  delete rsvm; -  return 0; +int train() { +    RSVM *rsvm; +    rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); +    FileDP dp(vm["feature"].as<std::string>().c_str()); +    rsvm->train(dp); +    rsvm->saveModel(vm["output"].as<std::string>().c_str()); +    delete rsvm; +    return 0;  } -int main(int argc,char ** argv) -{ -  // Defining program options -  po::options_description desc("Allowed options"); -  desc.add_options() -          ("help,h", "produce help message") -          ("model,m", po::value<std::string>()->default_value("test.m"), "set input model file") -          ("feature,i", po::value<std::string>()->default_value("test.f"), "set input feature file") -          ("output,o", po::value<std::string>()->default_value("test.m.out"), "set output model file") -          ; - -  // Parsing program options -  po::store(po::parse_command_line(argc, argv, desc), vm); -  po::notify(vm); - -  if (vm.count("help")) { -    std::cout << desc << "\n"; -    return 1; -  } - -  MatrixXd m(2,2); -  m(0,0) = 3; -  m(1,0) = 2.5; -  m(0,1) = -1; -  m(1,1) = m(1,0) + m(0,1); -  std::cout << m << std::endl; +int predict() { +    RSVM *rsvm; +    rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); +    FileDP dp(vm["feature"].as<std::string>().c_str()); +    rsvm->predict(dp); +    delete rsvm; +    return 0; +} + +int main(int argc, char **argv) { +    // Defining program options +    po::options_description desc("Allowed options"); +    desc.add_options() +            ("help,h", "produce help message") +            ("train,T", "training model") +            ("validate,V", "validate model") +            ("predict,P", "use model for prediction") +            ("model,m", po::value<std::string>(), "set input model file") +            ("output,o", po::value<std::string>(), "set output model file") +            ("feature,i", po::value<std::string>(), "set input feature file"); + +    // Parsing program options +    po::store(po::parse_command_line(argc, argv, desc), vm); +    po::notify(vm); + +    // Print help if necessary +    if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { +        std::cout << desc; +        return 0; +    } + +    if (vm.count("train")) { +        LOG(INFO) << "Program option: training"; +        train(); +    } +    else if (vm.count("validate")) { +        LOG(INFO) << "Program option: validate"; +        predict(); +    } +    else if (vm.count("predict")) { +        LOG(INFO) << "Program option: predict"; +        predict(); +    } +    return 0;  } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 58a097a..6294245 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -25,7 +25,6 @@ RSVM* RSVM::loadModel(const string fname){      RSVM* rsvm; -    // TODO multiplex type      if (type=="TN")          rsvm = new RSVMTN(); @@ -39,7 +38,7 @@ RSVM* RSVM::loadModel(const string fname){  int RSVM::setModel(const Eigen::VectorXd &model) {      if (model.cols()!=fsize) -        LOG(FATAL) << "Feature size mismatch";; +        LOG(FATAL) << "Feature size mismatch";      this->model=model;      return 0;  }
\ No newline at end of file diff --git a/model/ranksvm.h b/model/ranksvm.h index fad790d..e7b7c4a 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,8 +12,8 @@ protected:      Eigen::VectorXd model;      int fsize;  public: -    virtual int train(DataProvider &D)=0; -    virtual int predict(DataProvider &D)=0; +    virtual int train(DataProvider &D)=0; // Dataprovider will have to provide label +    virtual int predict(DataProvider &D)=0;  // TODO Not sure how to construct this      int saveModel(const std::string fname);      static RSVM* loadModel(const std::string fname);      virtual std::string getName()=0; diff --git a/tools/dataProvider.h b/tools/dataProvider.h index b2384c9..d9440ce 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -4,6 +4,8 @@  #include<Eigen/Dense>  #include "../tools/easylogging++.h" +// TODO decide how to construct training data +  class DataProvider  //Virtual base class for data input  {  protected:  | 
