diff options
| -rw-r--r-- | CMakeLists.txt | 8 | ||||
| -rw-r--r-- | main.cpp | 40 | ||||
| -rw-r--r-- | model/ranksvm.cpp | 18 | ||||
| -rw-r--r-- | model/ranksvm.h | 10 | ||||
| -rw-r--r-- | model/ranksvmtn.h | 4 | ||||
| -rw-r--r-- | tools/dataProvider.h | 4 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 12 | ||||
| -rw-r--r-- | tools/matrixIO.h | 30 | 
8 files changed, 101 insertions, 25 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index a5d63a4..99c798e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,5 +8,9 @@ INCLUDE_DIRECTORIES ( "/usr/include/eigen3" )  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -set(SOURCE_FILES main.cpp) -add_executable(eigenTest ${SOURCE_FILES})
\ No newline at end of file +FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) +INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) + +set(SOURCE_FILES main.cpp ./model/ranksvm.cpp) +add_executable(eigenTest ${SOURCE_FILES}) +TARGET_LINK_LIBRARIES( eigenTest ${Boost_LIBRARIES} )
\ No newline at end of file @@ -1,18 +1,54 @@  #include <iostream>  #include <Eigen/Dense> +#include <boost/program_options.hpp> +#include <string>  #include "tools/easylogging++.h" +#include "model/ranksvmtn.h" +#include "tools/fileDataProvider.h"  INITIALIZE_EASYLOGGINGPP  using Eigen::MatrixXd; +namespace po = boost::program_options; -int main() +po::variables_map vm; + +int train() +{ +  RSVM* rsvm; +  rsvm=RSVM::loadModel(vm["model"].as<std::string>().c_str()); +  FileDP dp(vm["feature"].as<std::string>().c_str()); +  rsvm->train(dp); +  rsvm->predict(dp); +  rsvm->saveModel(vm["output"].as<std::string>().c_str()); +  delete rsvm; +  return 0; +} + +int main(int argc,char ** argv)  { +  // Defining program options +  po::options_description desc("Allowed options"); +  desc.add_options() +          ("help,h", "produce help message") +          ("model,m", po::value<std::string>()->default_value("test.m"), "set input model file") +          ("feature,i", po::value<std::string>()->default_value("test.f"), "set input feature file") +          ("output,o", po::value<std::string>()->default_value("test.m.out"), "set output model file") +          ; + +  // Parsing program options +  po::store(po::parse_command_line(argc, argv, desc), vm); +  po::notify(vm); + +  if (vm.count("help")) { +    std::cout << desc << "\n"; +    return 1; +  } +    MatrixXd m(2,2);    m(0,0) = 3;    m(1,0) = 2.5;    m(0,1) = -1;    m(1,1) = m(1,0) + m(0,1); -  LOG(FATAL) << "My first info log using default logger";    std::cout << m << std::endl;  } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index b15d2ef..58a097a 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -1,21 +1,23 @@  #include"ranksvm.h" -#include"ranksvmtron.h" +#include"ranksvmtn.h" +#include"../tools/matrixIO.h"  #include<iostream>  #include<fstream>  #include<string>  using namespace Eigen; +using namespace std; -int RSVM::saveModel(string fname){ +int RSVM::saveModel(const string fname){ -    std::ofstream fout(fname); +    std::ofstream fout(fname.c_str());      fout<<this->getName()<<endl;      fout<<this->model;      return 0;  } -static RSVM* RSVM::loadModel(string fname){ -    std::ifstream fin(fname); +RSVM* RSVM::loadModel(const string fname){ +    std::ifstream fin(fname.c_str());      std::string type;      int fsize;      fin>>type; @@ -25,17 +27,17 @@ static RSVM* RSVM::loadModel(string fname){      // TODO multiplex type      if (type=="TN") -        RSVM = new RSVMTN(); +        rsvm = new RSVMTN();      rsvm->fsize=fsize;      VectorXd model; -    fin>>model; +    Eigen::read_stream(fin, model);      rsvm->setModel(model);      return rsvm;  } -int RSVM::setModel(Eigen::VectorXd model) { +int RSVM::setModel(const Eigen::VectorXd &model) {      if (model.cols()!=fsize)          LOG(FATAL) << "Feature size mismatch";;      this->model=model; diff --git a/model/ranksvm.h b/model/ranksvm.h index 8993b87..fad790d 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,14 +12,14 @@ protected:      Eigen::VectorXd model;      int fsize;  public: -    virtual int train(DataProvider D)=0; -    virtual int predict(DataProvider D); -    int saveModel(std::string fname); -    static RSVM loadModel(std::string fname); +    virtual int train(DataProvider &D)=0; +    virtual int predict(DataProvider &D)=0; +    int saveModel(const std::string fname); +    static RSVM* loadModel(const std::string fname);      virtual std::string getName()=0;      Eigen::MatrixXd getModel(){          return model;}; -    void setModel(Eigen::VectorXd model); +    int setModel(const Eigen::VectorXd &model);  };  #endif
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 2a8f524..4a0fb16 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -13,8 +13,8 @@ public:          return "TN";      }; -    int train(DataProvider D){return 0;}; -    int predict(DataProvider D){return 0;}; +    int train(DataProvider &D){return 0;}; +    int predict(DataProvider &D){return 0;};  };  #endif
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index d598f3f..b2384c9 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -19,8 +19,8 @@ public:      virtual Eigen::MatrixXd* getAttr() = 0;      virtual Eigen::VectorXd* getPref() = 0; -    virtual int open(); -    virtual bool EOFile(); +    virtual int open()=0; +    virtual bool EOFile()=0;  };  #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 3cfb3f9..7937866 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -11,10 +11,14 @@ private:  public:      FileDP(){};      FileDP(std::string fn):fname(fn){}; -    virtual Eigen::MatrixXd* getNextAttr(); -    virtual Eigen::VectorXd* getNextPref(); -    virtual int open(); -    virtual bool EOFile(); +    virtual Eigen::MatrixXd* getAttr(){ +        return new Eigen::MatrixXd(3,3); +    } +    virtual Eigen::VectorXd* getPref(){ +        return new Eigen::VectorXd(3); +    }; +    virtual int open(){return 0;}; +    virtual bool EOFile(){return true;};  };  #endif
\ No newline at end of file diff --git a/tools/matrixIO.h b/tools/matrixIO.h new file mode 100644 index 0000000..88cd419 --- /dev/null +++ b/tools/matrixIO.h @@ -0,0 +1,30 @@ +#ifndef MATIO_H +#define MATIO_H + +#include<iostream> + +namespace Eigen{ +    template<class Matrix> +    void write_stream(std::ostream &ostr, const Matrix& matrix){ +        typename Matrix::Index rows=matrix.rows(), cols=matrix.cols(); +        ostr<<rows<<" "<<cols<<std::endl; +        for (int r=0;r<rows;++r) +        { +            for (int c=0;c<cols;++c) +                ostr<<matrix(r,c)<<" "; +            ostr<<std::endl; +        } +    } +    template<class Matrix> +    void read_stream(std::istream &istr, Matrix& matrix){ +        typename Matrix::Index rows=0, cols=0; +        istr>>rows>>cols; +        matrix.resize(rows, cols); +        for (int r=0;r<rows;++r) +            for (int c=0;c<cols;++c) +                istr>>matrix(r,c); +    } +} // Eigen:: + + +#endif
\ No newline at end of file  | 
