diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-03-11 00:55:41 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-03-11 00:55:41 +0800 | 
| commit | f6c22b46449fa77f90e319e4b159ccb6c2a5732b (patch) | |
| tree | bc38dde03ffa5cb3c2fea9f5fefff0b990de405b | |
| parent | 3d204f5fe4614624ca342090feecbfe4df188d9d (diff) | |
| download | ranksvm-f6c22b46449fa77f90e319e4b159ccb6c2a5732b.tar.gz ranksvm-f6c22b46449fa77f90e319e4b159ccb6c2a5732b.tar.bz2 ranksvm-f6c22b46449fa77f90e319e4b159ccb6c2a5732b.zip | |
restructure, changed label type
| -rw-r--r-- | main.cpp | 17 | ||||
| -rw-r--r-- | model/ranksvm.cpp | 2 | ||||
| -rw-r--r-- | model/ranksvm.h | 4 | ||||
| -rw-r--r-- | model/ranksvmtn.cpp | 2 | ||||
| -rw-r--r-- | model/ranksvmtn.h | 2 | ||||
| -rw-r--r-- | tools/dataProvider.h | 3 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 1 | 
7 files changed, 17 insertions, 14 deletions
| @@ -5,6 +5,7 @@  #include "tools/easylogging++.h"  #include "model/ranksvmtn.h"  #include "tools/fileDataProvider.h" +#include "tools/matrixIO.h"  INITIALIZE_EASYLOGGINGPP @@ -22,17 +23,15 @@ int train() {      dp.open();      DataSet D;      Labels L; +      LOG(INFO)<<"Training started"; -    while (!dp.EOFile()) -    { -        dp.getDataSet(D); -        dp.getLabel(L); -        rsvm->train(D,L); -    } +    dp.getDataSet(D); +    dp.getLabel(L); +    rsvm->train(D,L);      LOG(INFO)<<"Training finished,saving model"; - +    dp.close();      rsvm->saveModel(vm["output"].as<std::string>().c_str());      delete rsvm;      return 0; @@ -43,12 +42,14 @@ int predict() {      rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());      FileDP dp(vm["feature"].as<std::string>().c_str());      DataSet D; -    MatrixXd L; +    Labels L;      while (!dp.EOFile())      {          dp.getDataSet(D);          rsvm->predict(D,L);      } + +    Eigen::write_stream(std::cout, L);      delete rsvm;      return 0;  } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 060001b..628ef37 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -37,7 +37,7 @@ RSVM* RSVM::loadModel(const string fname){      return rsvm;  } -int RSVM::setModel(const Eigen::VectorXd &model) { +int RSVM::setModel(const Labels &model) {      if (model.rows()!=fsize)          LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.cols();      this->model=model; diff --git a/model/ranksvm.h b/model/ranksvm.h index 21fb30b..5217b56 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -13,7 +13,7 @@ protected:      int fsize;  public:      virtual int train(DataSet &D, Labels &label)=0; -    virtual int predict(DataSet &D, Eigen::MatrixXd &res)=0; +    virtual int predict(DataSet &D, Labels &res)=0;      // TODO Not sure how to construct this      //  Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.      int saveModel(const std::string fname); @@ -21,7 +21,7 @@ public:      virtual std::string getName()=0;      Eigen::MatrixXd getModel(){          return model;}; -    int setModel(const Eigen::VectorXd &model); +    int setModel(const Labels &model);  };  #endif
\ No newline at end of file diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index ef8d98c..746e967 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -7,6 +7,6 @@ int RSVMTN::train(DataSet &D, Labels &label){      return 0;  }; -int RSVMTN::predict(DataSet &D, MatrixXd &res){ +int RSVMTN::predict(DataSet &D, Labels &res){      return 0;  };
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 21b03bd..cdb9796 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -13,7 +13,7 @@ public:          return "TN";      };      virtual int train(DataSet &D, Labels &label); -    virtual int predict(DataSet &D, Eigen::MatrixXd &res); +    virtual int predict(DataSet &D, Labels &res);  };  #endif
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index ce2bf12..d311149 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -17,7 +17,7 @@  typedef Eigen::MatrixXd DataSet; -typedef std::vector<double> Labels; +typedef Eigen::VectorXd Labels;  class DataProvider  //Virtual base class for data input  { @@ -39,6 +39,7 @@ public:      virtual int getDataSet(DataSet &out) = 0;      virtual int getLabel(Labels &out) = 0;      virtual int open()=0; +    virtual int close()=0;  };  #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 8a499ca..a4cc252 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -18,6 +18,7 @@ public:          return 0;      };      virtual int open(){eof=true;return 0;}; +    virtual int close(){return 0;};  };  #endif
\ No newline at end of file | 
