diff options
| -rw-r--r-- | main.cpp | 4 | ||||
| -rw-r--r-- | model/ranksvm.cpp | 44 | ||||
| -rw-r--r-- | model/ranksvm.h | 19 | ||||
| -rw-r--r-- | model/ranksvmtn.h | 20 | ||||
| -rw-r--r-- | tools/dataProvider.h | 7 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 20 | 
6 files changed, 102 insertions, 12 deletions
@@ -1,5 +1,8 @@  #include <iostream>  #include <Eigen/Dense> +#include "tools/easylogging++.h" + +INITIALIZE_EASYLOGGINGPP  using Eigen::MatrixXd; @@ -10,5 +13,6 @@ int main()    m(1,0) = 2.5;    m(0,1) = -1;    m(1,1) = m(1,0) + m(0,1); +  LOG(FATAL) << "My first info log using default logger";    std::cout << m << std::endl;  } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 38fb70c..b15d2ef 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -1 +1,43 @@ -#include"ranksvm.h"
\ No newline at end of file +#include"ranksvm.h" +#include"ranksvmtron.h" +#include<iostream> +#include<fstream> +#include<string> + +using namespace Eigen; + +int RSVM::saveModel(string fname){ + +    std::ofstream fout(fname); +    fout<<this->getName()<<endl; +    fout<<this->model; +    return 0; +} + +static RSVM* RSVM::loadModel(string fname){ +    std::ifstream fin(fname); +    std::string type; +    int fsize; +    fin>>type; +    fin>>fsize; + +    RSVM* rsvm; + +    // TODO multiplex type +    if (type=="TN") +        RSVM = new RSVMTN(); + +    rsvm->fsize=fsize; +    VectorXd model; +    fin>>model; +    rsvm->setModel(model); + +    return rsvm; +} + +int RSVM::setModel(Eigen::VectorXd model) { +    if (model.cols()!=fsize) +        LOG(FATAL) << "Feature size mismatch";; +    this->model=model; +    return 0; +}
\ No newline at end of file diff --git a/model/ranksvm.h b/model/ranksvm.h index ba79c48..8993b87 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -4,19 +4,22 @@  #include<Eigen/Dense>  #include<string>  #include"../tools/dataProvider.h" +#include "../tools/easylogging++.h"  class RSVM  //Virtual base class for all RSVM operations  {  protected: -    Eigen::VectorXd* model; +    Eigen::VectorXd model; +    int fsize;  public: -    virtual int train(DataProvider &D)=0; -    int test(); -    int saveModel(string fname); -    static RSVM loadModel(string fname); -    string getName(); -    Eigen::MatrixXd getModel(); -    Eigen::MatrixXd setModel(); +    virtual int train(DataProvider D)=0; +    virtual int predict(DataProvider D); +    int saveModel(std::string fname); +    static RSVM loadModel(std::string fname); +    virtual std::string getName()=0; +    Eigen::MatrixXd getModel(){ +        return model;}; +    void setModel(Eigen::VectorXd model);  };  #endif
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h new file mode 100644 index 0000000..2a8f524 --- /dev/null +++ b/model/ranksvmtn.h @@ -0,0 +1,20 @@ +#ifndef RSVMTN_H +#define RSVMTN_H + +// Truncated Newton method based RankSVM + +#include"ranksvm.h" + +class RSVMTN:public RSVM +{ +public: +    std::string getName() +    { +        return "TN"; +    }; + +    int train(DataProvider D){return 0;}; +    int predict(DataProvider D){return 0;}; +}; + +#endif
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 70a87f0..d598f3f 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -2,10 +2,11 @@  #define DATAPROV_H  #include<Eigen/Dense> +#include "../tools/easylogging++.h"  class DataProvider  //Virtual base class for data input  { -private: +protected:      int size;      int attrSize;  public: @@ -17,9 +18,9 @@ public:      }      virtual Eigen::MatrixXd* getAttr() = 0; -    virtual Eigen::MatrixXd* getPref() = 0; +    virtual Eigen::VectorXd* getPref() = 0;      virtual int open(); -    virtual int parse(); +    virtual bool EOFile();  };  #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h new file mode 100644 index 0000000..3cfb3f9 --- /dev/null +++ b/tools/fileDataProvider.h @@ -0,0 +1,20 @@ +#ifndef FDPROV_H +#define FDPROV_H + +#include "dataProvider.h" +#include <string> + +class FileDP:public DataProvider +{ +private: +    std::string fname; +public: +    FileDP(){}; +    FileDP(std::string fn):fname(fn){}; +    virtual Eigen::MatrixXd* getNextAttr(); +    virtual Eigen::VectorXd* getNextPref(); +    virtual int open(); +    virtual bool EOFile(); +}; + +#endif
\ No newline at end of file  | 
