diff options
| -rw-r--r-- | CMakeLists.txt | 2 | ||||
| -rw-r--r-- | main.cpp | 17 | ||||
| -rw-r--r-- | model/ranksvm.h | 6 | ||||
| -rw-r--r-- | model/ranksvmtn.cpp | 12 | ||||
| -rw-r--r-- | model/ranksvmtn.h | 5 | ||||
| -rw-r--r-- | tools/dataProvider.h | 17 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 15 | 
7 files changed, 57 insertions, 17 deletions
| diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ccf0c8..e1eb353 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")  FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )  INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) -set(SOURCE_FILES main.cpp ./model/ranksvm.cpp) +set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp)  add_executable(ranksvm ${SOURCE_FILES})  TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
\ No newline at end of file @@ -17,7 +17,14 @@ int train() {      RSVM *rsvm;      rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());      FileDP dp(vm["feature"].as<std::string>().c_str()); -    rsvm->train(dp); +    DataSet D; +    Labels L; +    while (!dp.EOFile()) +    { +        dp.getDataSet(D); +        dp.getLabel(L); +        rsvm->train(D,L); +    }      rsvm->saveModel(vm["output"].as<std::string>().c_str());      delete rsvm;      return 0; @@ -27,7 +34,13 @@ int predict() {      RSVM *rsvm;      rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());      FileDP dp(vm["feature"].as<std::string>().c_str()); -    rsvm->predict(dp); +    DataSet D; +    MatrixXd L; +    while (!dp.EOFile()) +    { +        dp.getDataSet(D); +        rsvm->predict(D,L); +    }      delete rsvm;      return 0;  } diff --git a/model/ranksvm.h b/model/ranksvm.h index e7b7c4a..21fb30b 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,8 +12,10 @@ protected:      Eigen::VectorXd model;      int fsize;  public: -    virtual int train(DataProvider &D)=0; // Dataprovider will have to provide label -    virtual int predict(DataProvider &D)=0;  // TODO Not sure how to construct this +    virtual int train(DataSet &D, Labels &label)=0; +    virtual int predict(DataSet &D, Eigen::MatrixXd &res)=0; +    // TODO Not sure how to construct this +    //  Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.      int saveModel(const std::string fname);      static RSVM* loadModel(const std::string fname);      virtual std::string getName()=0; diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp new file mode 100644 index 0000000..ef8d98c --- /dev/null +++ b/model/ranksvmtn.cpp @@ -0,0 +1,12 @@ +#include "ranksvmtn.h" + +using namespace std; +using namespace Eigen; + +int RSVMTN::train(DataSet &D, Labels &label){ +    return 0; +}; + +int RSVMTN::predict(DataSet &D, MatrixXd &res){ +    return 0; +};
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 4a0fb16..21b03bd 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -12,9 +12,8 @@ public:      {          return "TN";      }; - -    int train(DataProvider &D){return 0;}; -    int predict(DataProvider &D){return 0;}; +    virtual int train(DataSet &D, Labels &label); +    virtual int predict(DataSet &D, Eigen::MatrixXd &res);  };  #endif
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index d9440ce..0e6ed9e 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -3,8 +3,21 @@  #include<Eigen/Dense>  #include "../tools/easylogging++.h" +#include<vector>  // TODO decide how to construct training data +// One possible way for training data: +//  Matrix composed of an array of feature vectors +//  Labels are composed of linked list, such as +//      6,3,4,0,5,0,0 +//      =>  0->6 | 1->3 | 2->4->5 +//  How to compensate for non exhaustive labeling? +//      Use -1 to indicate not yet labeled data +//      -1s will be excluded from training + +typedef Eigen::MatrixXd DataSet; + +typedef std::vector<double> Labels;  class DataProvider  //Virtual base class for data input  { @@ -19,8 +32,8 @@ public:          return attrSize;      } -    virtual Eigen::MatrixXd* getAttr() = 0; -    virtual Eigen::VectorXd* getPref() = 0; +    virtual int getDataSet(DataSet &out) = 0; +    virtual int getLabel(Labels &out) = 0;      virtual int open()=0;      virtual bool EOFile()=0;  }; diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 7937866..fd8f00d 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -8,17 +8,18 @@ class FileDP:public DataProvider  {  private:      std::string fname; +    bool eof;  public: -    FileDP(){}; -    FileDP(std::string fn):fname(fn){}; -    virtual Eigen::MatrixXd* getAttr(){ -        return new Eigen::MatrixXd(3,3); +    FileDP(std::string fn=""):fname(fn),eof(false){}; +    void setFname(std::string fn){fname=fn;}; +    virtual int getDataSet(DataSet &out){ +        return 0;      } -    virtual Eigen::VectorXd* getPref(){ -        return new Eigen::VectorXd(3); +    virtual int getLabel(Labels &out){ +        return 0;      };      virtual int open(){return 0;}; -    virtual bool EOFile(){return true;}; +    virtual bool EOFile(){return eof;};  };  #endif
\ No newline at end of file | 
