diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-04-10 20:39:00 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-10 20:39:00 +0800 | 
| commit | 705f3731f4c49a75e2824d16622ff853634335c7 (patch) | |
| tree | 8c6a171615f27d0cb25484f72ccf1f84391eb9c3 | |
| parent | 01b523c7ce4eb5e692b0dcbec63efac0e8d1e2c7 (diff) | |
| download | ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.tar.gz ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.tar.bz2 ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.zip  | |
structuring input
| -rw-r--r-- | main.cpp | 14 | ||||
| -rw-r--r-- | model/ranksvm.h | 5 | ||||
| -rw-r--r-- | model/ranksvmtn.cpp | 13 | ||||
| -rw-r--r-- | model/ranksvmtn.h | 4 | ||||
| -rw-r--r-- | tools/dataProvider.h | 27 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 25 | 
6 files changed, 52 insertions, 36 deletions
@@ -1,7 +1,7 @@  #include <iostream>  #include <Eigen/Dense>  #include <boost/program_options.hpp> -#include <string> +#include <list>  #include "tools/easylogging++.h"  #include "model/ranksvmtn.h"  #include "tools/fileDataProvider.h" @@ -21,13 +21,11 @@ int train() {      // Generic training operations      dp.open(); -    DataSet D; -    Labels L; +    DataList D;      LOG(INFO)<<"Training started";      dp.getDataSet(D); -    dp.getLabel(L); -    rsvm->train(D,L); +    rsvm->train(D);      LOG(INFO)<<"Training finished,saving model"; @@ -41,15 +39,15 @@ int predict() {      RSVM *rsvm;      rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());      FileDP dp(vm["feature"].as<std::string>().c_str()); -    DataSet D; -    Labels L; +    DataList D; +    std::list<double> L;      while (!dp.EOFile())      {          dp.getDataSet(D);          rsvm->predict(D,L);      } -    Eigen::write_stream(std::cout, L); +    // TODO output Eigen::write_stream(std::cout, L);      delete rsvm;      return 0;  } diff --git a/model/ranksvm.h b/model/ranksvm.h index b4ec7ce..e82b6be 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -3,6 +3,7 @@  #include<Eigen/Dense>  #include<string> +#include<list>  #include"../tools/dataProvider.h"  #include "../tools/easylogging++.h" @@ -24,8 +25,8 @@ protected:      SVMModel model;      int fsize;  public: -    virtual int train(DataSet &D, Labels &label)=0; -    virtual int predict(DataSet &D, Labels &res)=0; +    virtual int train(DataList &D)=0; +    virtual int predict(DataList &D,std::list<double> &res)=0;      // TODO Not sure how to construct this      //  Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.      int saveModel(const std::string fname); diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 3c9808b..821d231 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -1,5 +1,6 @@  #include "ranksvmtn.h"  #include<iostream> +#include<list>  #include"../tools/matrixIO.h"  using namespace std; @@ -78,8 +79,8 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect      return 0;  } -int RSVMTN::train(DataSet &D, Labels &label){ -    int iter = 0; +int RSVMTN::train(DataList &D){ +    /*int iter = 0;      double C=1;      MatrixXd A; @@ -121,13 +122,13 @@ int RSVMTN::train(DataSet &D, Labels &label){          // When dec is small enough          if (-step.dot(grad) < prec * obj)              break; -    } +    }*/      return 0;  }; -int RSVMTN::predict(DataSet &D, Labels &res){ +int RSVMTN::predict(DataList &D, list<double> &res){      //TODO define A -    MatrixXd A; -    res = A*(D * model.weight); +    for (list<DataEntry*>::iterator i=D.getData().begin(), end=D.getData().end();i!=end;++i) +        res.push_back(((*i)->feature).dot(model.weight));      return 0;  };
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 703fee4..6ed6ad7 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -12,8 +12,8 @@ public:      {          return "TN";      }; -    virtual int train(DataSet &D, Labels &label); -    virtual int predict(DataSet &D, Labels &res); +    virtual int train(DataList &D); +    virtual int predict(DataList &D,std::list<double> &res);  };  int cg_solve(const Eigen::MatrixXd &A, const Eigen::VectorXd &b, Eigen::VectorXd &x); diff --git a/tools/dataProvider.h b/tools/dataProvider.h index bff1f44..fbf554b 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -16,10 +16,6 @@  //      Use -1 to indicate not yet labeled data  //      -1s will be excluded from training -typedef Eigen::MatrixXd DataSet; - -typedef Eigen::VectorXd Labels; -  typedef struct DataEntry{      int qid;      double rank; @@ -29,28 +25,31 @@ typedef struct DataEntry{  class DataList{  private:      int n; -    std::list<DataEntry> data; +    std::list<DataEntry*> data;  public:      int getSize(){return data.size();} -    void addEntry(DataEntry d){data.push_front(d);} +    void addEntry(DataEntry* d){data.push_front(d);}      void setfSize(int fsize){n=fsize;}      int getfSize(){return n;} +    int clear(){ +        for (std::list<DataEntry*>::iterator i=data.begin(),end=data.end();i!=end;++i) +            delete *i; +        data.clear(); +    } +    std::list<DataEntry*> getData(){ +        return data; +    } +    ~DataList(){ +        clear(); +    }  };  class DataProvider  //Virtual base class for data input  {  protected: -    int size; -    int attrSize;      bool eof;  public:      DataProvider():eof(false){}; -    int getSize(){ -        return size; -    } -    int getAttrSize(){ -        return attrSize; -    }      bool EOFile(){return eof;} diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 9ce78e6..6ccf28f 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -3,21 +3,38 @@  #include "dataProvider.h"  #include <string> +#include <iostream> +#include <fstream>  class FileDP:public DataProvider  {  private:      std::string fname; +    std::ifstream fin;  public:      FileDP(std::string fn=""):fname(fn){};      void setFname(std::string fn){fname=fn;};      virtual int getDataSet(DataList &out){ +        DataEntry* e; +        out.clear(); +        int fsize; +        out.setfSize(fsize); +        fin>>fsize; +        while (!fin.eof()) { +            e= new DataEntry; +            fin>>e->rank; +            fin>>e->qid; +            e->feature.resize(fsize); +            for (int i=0;i<fsize;++i) { +                fin>>e->feature(i); +            } +            out.addEntry(e); +        } +        eof=true;          return 0;      } -    int getDataSet(DataSet &D) {return 0;} -    int getLabel(Labels &l) {return 0;} -    virtual int open(){eof=true;return 0;}; -    virtual int close(){return 0;}; +    virtual int open(){fin.open(fname); eof=false;return 0;}; +    virtual int close(){fin.close();return 0;};  };  #endif
\ No newline at end of file  | 
