summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt6
-rw-r--r--main.cpp88
-rw-r--r--model/ranksvm.cpp3
-rw-r--r--model/ranksvm.h4
-rw-r--r--tools/dataProvider.h2
5 files changed, 60 insertions, 43 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 99c798e..6ccf0c8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 2.8.4)
-project(eigenTest)
+project(ranksvm)
# Use Eigen3 Library for Linear Algebra
INCLUDE_DIRECTORIES ( "/usr/include/eigen3" )
@@ -12,5 +12,5 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )
INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR})
set(SOURCE_FILES main.cpp ./model/ranksvm.cpp)
-add_executable(eigenTest ${SOURCE_FILES})
-TARGET_LINK_LIBRARIES( eigenTest ${Boost_LIBRARIES} ) \ No newline at end of file
+add_executable(ranksvm ${SOURCE_FILES})
+TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) \ No newline at end of file
diff --git a/main.cpp b/main.cpp
index d23ded2..8d5b393 100644
--- a/main.cpp
+++ b/main.cpp
@@ -13,42 +13,58 @@ namespace po = boost::program_options;
po::variables_map vm;
-int train()
-{
- RSVM* rsvm;
- rsvm=RSVM::loadModel(vm["model"].as<std::string>().c_str());
- FileDP dp(vm["feature"].as<std::string>().c_str());
- rsvm->train(dp);
- rsvm->predict(dp);
- rsvm->saveModel(vm["output"].as<std::string>().c_str());
- delete rsvm;
- return 0;
+int train() {
+ RSVM *rsvm;
+ rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
+ FileDP dp(vm["feature"].as<std::string>().c_str());
+ rsvm->train(dp);
+ rsvm->saveModel(vm["output"].as<std::string>().c_str());
+ delete rsvm;
+ return 0;
}
-int main(int argc,char ** argv)
-{
- // Defining program options
- po::options_description desc("Allowed options");
- desc.add_options()
- ("help,h", "produce help message")
- ("model,m", po::value<std::string>()->default_value("test.m"), "set input model file")
- ("feature,i", po::value<std::string>()->default_value("test.f"), "set input feature file")
- ("output,o", po::value<std::string>()->default_value("test.m.out"), "set output model file")
- ;
-
- // Parsing program options
- po::store(po::parse_command_line(argc, argv, desc), vm);
- po::notify(vm);
-
- if (vm.count("help")) {
- std::cout << desc << "\n";
- return 1;
- }
-
- MatrixXd m(2,2);
- m(0,0) = 3;
- m(1,0) = 2.5;
- m(0,1) = -1;
- m(1,1) = m(1,0) + m(0,1);
- std::cout << m << std::endl;
+int predict() {
+ RSVM *rsvm;
+ rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
+ FileDP dp(vm["feature"].as<std::string>().c_str());
+ rsvm->predict(dp);
+ delete rsvm;
+ return 0;
+}
+
+int main(int argc, char **argv) {
+ // Defining program options
+ po::options_description desc("Allowed options");
+ desc.add_options()
+ ("help,h", "produce help message")
+ ("train,T", "training model")
+ ("validate,V", "validate model")
+ ("predict,P", "use model for prediction")
+ ("model,m", po::value<std::string>(), "set input model file")
+ ("output,o", po::value<std::string>(), "set output model file")
+ ("feature,i", po::value<std::string>(), "set input feature file");
+
+ // Parsing program options
+ po::store(po::parse_command_line(argc, argv, desc), vm);
+ po::notify(vm);
+
+ // Print help if necessary
+ if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) {
+ std::cout << desc;
+ return 0;
+ }
+
+ if (vm.count("train")) {
+ LOG(INFO) << "Program option: training";
+ train();
+ }
+ else if (vm.count("validate")) {
+ LOG(INFO) << "Program option: validate";
+ predict();
+ }
+ else if (vm.count("predict")) {
+ LOG(INFO) << "Program option: predict";
+ predict();
+ }
+ return 0;
}
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp
index 58a097a..6294245 100644
--- a/model/ranksvm.cpp
+++ b/model/ranksvm.cpp
@@ -25,7 +25,6 @@ RSVM* RSVM::loadModel(const string fname){
RSVM* rsvm;
- // TODO multiplex type
if (type=="TN")
rsvm = new RSVMTN();
@@ -39,7 +38,7 @@ RSVM* RSVM::loadModel(const string fname){
int RSVM::setModel(const Eigen::VectorXd &model) {
if (model.cols()!=fsize)
- LOG(FATAL) << "Feature size mismatch";;
+ LOG(FATAL) << "Feature size mismatch";
this->model=model;
return 0;
} \ No newline at end of file
diff --git a/model/ranksvm.h b/model/ranksvm.h
index fad790d..e7b7c4a 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -12,8 +12,8 @@ protected:
Eigen::VectorXd model;
int fsize;
public:
- virtual int train(DataProvider &D)=0;
- virtual int predict(DataProvider &D)=0;
+ virtual int train(DataProvider &D)=0; // Dataprovider will have to provide label
+ virtual int predict(DataProvider &D)=0; // TODO Not sure how to construct this
int saveModel(const std::string fname);
static RSVM* loadModel(const std::string fname);
virtual std::string getName()=0;
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index b2384c9..d9440ce 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -4,6 +4,8 @@
#include<Eigen/Dense>
#include "../tools/easylogging++.h"
+// TODO decide how to construct training data
+
class DataProvider //Virtual base class for data input
{
protected: