diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-03-08 16:02:15 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-03-08 16:02:15 +0800 |
commit | e500bb4cdb32b13cc022b6dc5d221de7ad97a73e (patch) | |
tree | 643bcf86336437ccee6182fc6d19c92f33b7201a | |
parent | 457024eedfaf6e08146038c8cb3034e590a81df6 (diff) | |
download | ranksvm-e500bb4cdb32b13cc022b6dc5d221de7ad97a73e.tar.gz ranksvm-e500bb4cdb32b13cc022b6dc5d221de7ad97a73e.tar.bz2 ranksvm-e500bb4cdb32b13cc022b6dc5d221de7ad97a73e.zip |
added commandline parser
-rw-r--r-- | CMakeLists.txt | 8 | ||||
-rw-r--r-- | main.cpp | 40 | ||||
-rw-r--r-- | model/ranksvm.cpp | 18 | ||||
-rw-r--r-- | model/ranksvm.h | 10 | ||||
-rw-r--r-- | model/ranksvmtn.h | 4 | ||||
-rw-r--r-- | tools/dataProvider.h | 4 | ||||
-rw-r--r-- | tools/fileDataProvider.h | 12 | ||||
-rw-r--r-- | tools/matrixIO.h | 30 |
8 files changed, 101 insertions, 25 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index a5d63a4..99c798e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,5 +8,9 @@ INCLUDE_DIRECTORIES ( "/usr/include/eigen3" ) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -set(SOURCE_FILES main.cpp) -add_executable(eigenTest ${SOURCE_FILES})
\ No newline at end of file +FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) +INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) + +set(SOURCE_FILES main.cpp ./model/ranksvm.cpp) +add_executable(eigenTest ${SOURCE_FILES}) +TARGET_LINK_LIBRARIES( eigenTest ${Boost_LIBRARIES} )
\ No newline at end of file @@ -1,18 +1,54 @@ #include <iostream> #include <Eigen/Dense> +#include <boost/program_options.hpp> +#include <string> #include "tools/easylogging++.h" +#include "model/ranksvmtn.h" +#include "tools/fileDataProvider.h" INITIALIZE_EASYLOGGINGPP using Eigen::MatrixXd; +namespace po = boost::program_options; -int main() +po::variables_map vm; + +int train() +{ + RSVM* rsvm; + rsvm=RSVM::loadModel(vm["model"].as<std::string>().c_str()); + FileDP dp(vm["feature"].as<std::string>().c_str()); + rsvm->train(dp); + rsvm->predict(dp); + rsvm->saveModel(vm["output"].as<std::string>().c_str()); + delete rsvm; + return 0; +} + +int main(int argc,char ** argv) { + // Defining program options + po::options_description desc("Allowed options"); + desc.add_options() + ("help,h", "produce help message") + ("model,m", po::value<std::string>()->default_value("test.m"), "set input model file") + ("feature,i", po::value<std::string>()->default_value("test.f"), "set input feature file") + ("output,o", po::value<std::string>()->default_value("test.m.out"), "set output model file") + ; + + // Parsing program options + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + if (vm.count("help")) { + std::cout << desc << "\n"; + return 1; + } + MatrixXd m(2,2); m(0,0) = 3; m(1,0) = 2.5; m(0,1) = -1; m(1,1) = m(1,0) + m(0,1); - LOG(FATAL) << "My first info log using default logger"; std::cout << m << std::endl; } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index b15d2ef..58a097a 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -1,21 +1,23 @@ #include"ranksvm.h" -#include"ranksvmtron.h" +#include"ranksvmtn.h" +#include"../tools/matrixIO.h" #include<iostream> #include<fstream> #include<string> using namespace Eigen; +using namespace std; -int RSVM::saveModel(string fname){ +int RSVM::saveModel(const string fname){ - std::ofstream fout(fname); + std::ofstream fout(fname.c_str()); fout<<this->getName()<<endl; fout<<this->model; return 0; } -static RSVM* RSVM::loadModel(string fname){ - std::ifstream fin(fname); +RSVM* RSVM::loadModel(const string fname){ + std::ifstream fin(fname.c_str()); std::string type; int fsize; fin>>type; @@ -25,17 +27,17 @@ static RSVM* RSVM::loadModel(string fname){ // TODO multiplex type if (type=="TN") - RSVM = new RSVMTN(); + rsvm = new RSVMTN(); rsvm->fsize=fsize; VectorXd model; - fin>>model; + Eigen::read_stream(fin, model); rsvm->setModel(model); return rsvm; } -int RSVM::setModel(Eigen::VectorXd model) { +int RSVM::setModel(const Eigen::VectorXd &model) { if (model.cols()!=fsize) LOG(FATAL) << "Feature size mismatch";; this->model=model; diff --git a/model/ranksvm.h b/model/ranksvm.h index 8993b87..fad790d 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,14 +12,14 @@ protected: Eigen::VectorXd model; int fsize; public: - virtual int train(DataProvider D)=0; - virtual int predict(DataProvider D); - int saveModel(std::string fname); - static RSVM loadModel(std::string fname); + virtual int train(DataProvider &D)=0; + virtual int predict(DataProvider &D)=0; + int saveModel(const std::string fname); + static RSVM* loadModel(const std::string fname); virtual std::string getName()=0; Eigen::MatrixXd getModel(){ return model;}; - void setModel(Eigen::VectorXd model); + int setModel(const Eigen::VectorXd &model); }; #endif
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 2a8f524..4a0fb16 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -13,8 +13,8 @@ public: return "TN"; }; - int train(DataProvider D){return 0;}; - int predict(DataProvider D){return 0;}; + int train(DataProvider &D){return 0;}; + int predict(DataProvider &D){return 0;}; }; #endif
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index d598f3f..b2384c9 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -19,8 +19,8 @@ public: virtual Eigen::MatrixXd* getAttr() = 0; virtual Eigen::VectorXd* getPref() = 0; - virtual int open(); - virtual bool EOFile(); + virtual int open()=0; + virtual bool EOFile()=0; }; #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 3cfb3f9..7937866 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -11,10 +11,14 @@ private: public: FileDP(){}; FileDP(std::string fn):fname(fn){}; - virtual Eigen::MatrixXd* getNextAttr(); - virtual Eigen::VectorXd* getNextPref(); - virtual int open(); - virtual bool EOFile(); + virtual Eigen::MatrixXd* getAttr(){ + return new Eigen::MatrixXd(3,3); + } + virtual Eigen::VectorXd* getPref(){ + return new Eigen::VectorXd(3); + }; + virtual int open(){return 0;}; + virtual bool EOFile(){return true;}; }; #endif
\ No newline at end of file diff --git a/tools/matrixIO.h b/tools/matrixIO.h new file mode 100644 index 0000000..88cd419 --- /dev/null +++ b/tools/matrixIO.h @@ -0,0 +1,30 @@ +#ifndef MATIO_H +#define MATIO_H + +#include<iostream> + +namespace Eigen{ + template<class Matrix> + void write_stream(std::ostream &ostr, const Matrix& matrix){ + typename Matrix::Index rows=matrix.rows(), cols=matrix.cols(); + ostr<<rows<<" "<<cols<<std::endl; + for (int r=0;r<rows;++r) + { + for (int c=0;c<cols;++c) + ostr<<matrix(r,c)<<" "; + ostr<<std::endl; + } + } + template<class Matrix> + void read_stream(std::istream &istr, Matrix& matrix){ + typename Matrix::Index rows=0, cols=0; + istr>>rows>>cols; + matrix.resize(rows, cols); + for (int r=0;r<rows;++r) + for (int c=0;c<cols;++c) + istr>>matrix(r,c); + } +} // Eigen:: + + +#endif
\ No newline at end of file |