summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-04-10 20:39:00 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-04-10 20:39:00 +0800
commit705f3731f4c49a75e2824d16622ff853634335c7 (patch)
tree8c6a171615f27d0cb25484f72ccf1f84391eb9c3
parent01b523c7ce4eb5e692b0dcbec63efac0e8d1e2c7 (diff)
downloadranksvm-705f3731f4c49a75e2824d16622ff853634335c7.tar.gz
ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.tar.bz2
ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.zip
structuring input
-rw-r--r--main.cpp14
-rw-r--r--model/ranksvm.h5
-rw-r--r--model/ranksvmtn.cpp13
-rw-r--r--model/ranksvmtn.h4
-rw-r--r--tools/dataProvider.h27
-rw-r--r--tools/fileDataProvider.h25
6 files changed, 52 insertions, 36 deletions
diff --git a/main.cpp b/main.cpp
index 09821d2..1cb18b9 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1,7 +1,7 @@
#include <iostream>
#include <Eigen/Dense>
#include <boost/program_options.hpp>
-#include <string>
+#include <list>
#include "tools/easylogging++.h"
#include "model/ranksvmtn.h"
#include "tools/fileDataProvider.h"
@@ -21,13 +21,11 @@ int train() {
// Generic training operations
dp.open();
- DataSet D;
- Labels L;
+ DataList D;
LOG(INFO)<<"Training started";
dp.getDataSet(D);
- dp.getLabel(L);
- rsvm->train(D,L);
+ rsvm->train(D);
LOG(INFO)<<"Training finished,saving model";
@@ -41,15 +39,15 @@ int predict() {
RSVM *rsvm;
rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
FileDP dp(vm["feature"].as<std::string>().c_str());
- DataSet D;
- Labels L;
+ DataList D;
+ std::list<double> L;
while (!dp.EOFile())
{
dp.getDataSet(D);
rsvm->predict(D,L);
}
- Eigen::write_stream(std::cout, L);
+ // TODO output Eigen::write_stream(std::cout, L);
delete rsvm;
return 0;
}
diff --git a/model/ranksvm.h b/model/ranksvm.h
index b4ec7ce..e82b6be 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -3,6 +3,7 @@
#include<Eigen/Dense>
#include<string>
+#include<list>
#include"../tools/dataProvider.h"
#include "../tools/easylogging++.h"
@@ -24,8 +25,8 @@ protected:
SVMModel model;
int fsize;
public:
- virtual int train(DataSet &D, Labels &label)=0;
- virtual int predict(DataSet &D, Labels &res)=0;
+ virtual int train(DataList &D)=0;
+ virtual int predict(DataList &D,std::list<double> &res)=0;
// TODO Not sure how to construct this
// Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.
int saveModel(const std::string fname);
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 3c9808b..821d231 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -1,5 +1,6 @@
#include "ranksvmtn.h"
#include<iostream>
+#include<list>
#include"../tools/matrixIO.h"
using namespace std;
@@ -78,8 +79,8 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
return 0;
}
-int RSVMTN::train(DataSet &D, Labels &label){
- int iter = 0;
+int RSVMTN::train(DataList &D){
+ /*int iter = 0;
double C=1;
MatrixXd A;
@@ -121,13 +122,13 @@ int RSVMTN::train(DataSet &D, Labels &label){
// When dec is small enough
if (-step.dot(grad) < prec * obj)
break;
- }
+ }*/
return 0;
};
-int RSVMTN::predict(DataSet &D, Labels &res){
+int RSVMTN::predict(DataList &D, list<double> &res){
//TODO define A
- MatrixXd A;
- res = A*(D * model.weight);
+ for (list<DataEntry*>::iterator i=D.getData().begin(), end=D.getData().end();i!=end;++i)
+ res.push_back(((*i)->feature).dot(model.weight));
return 0;
}; \ No newline at end of file
diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h
index 703fee4..6ed6ad7 100644
--- a/model/ranksvmtn.h
+++ b/model/ranksvmtn.h
@@ -12,8 +12,8 @@ public:
{
return "TN";
};
- virtual int train(DataSet &D, Labels &label);
- virtual int predict(DataSet &D, Labels &res);
+ virtual int train(DataList &D);
+ virtual int predict(DataList &D,std::list<double> &res);
};
int cg_solve(const Eigen::MatrixXd &A, const Eigen::VectorXd &b, Eigen::VectorXd &x);
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index bff1f44..fbf554b 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -16,10 +16,6 @@
// Use -1 to indicate not yet labeled data
// -1s will be excluded from training
-typedef Eigen::MatrixXd DataSet;
-
-typedef Eigen::VectorXd Labels;
-
typedef struct DataEntry{
int qid;
double rank;
@@ -29,28 +25,31 @@ typedef struct DataEntry{
class DataList{
private:
int n;
- std::list<DataEntry> data;
+ std::list<DataEntry*> data;
public:
int getSize(){return data.size();}
- void addEntry(DataEntry d){data.push_front(d);}
+ void addEntry(DataEntry* d){data.push_front(d);}
void setfSize(int fsize){n=fsize;}
int getfSize(){return n;}
+ int clear(){
+ for (std::list<DataEntry*>::iterator i=data.begin(),end=data.end();i!=end;++i)
+ delete *i;
+ data.clear();
+ }
+ std::list<DataEntry*> getData(){
+ return data;
+ }
+ ~DataList(){
+ clear();
+ }
};
class DataProvider //Virtual base class for data input
{
protected:
- int size;
- int attrSize;
bool eof;
public:
DataProvider():eof(false){};
- int getSize(){
- return size;
- }
- int getAttrSize(){
- return attrSize;
- }
bool EOFile(){return eof;}
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index 9ce78e6..6ccf28f 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -3,21 +3,38 @@
#include "dataProvider.h"
#include <string>
+#include <iostream>
+#include <fstream>
class FileDP:public DataProvider
{
private:
std::string fname;
+ std::ifstream fin;
public:
FileDP(std::string fn=""):fname(fn){};
void setFname(std::string fn){fname=fn;};
virtual int getDataSet(DataList &out){
+ DataEntry* e;
+ out.clear();
+ int fsize;
+ out.setfSize(fsize);
+ fin>>fsize;
+ while (!fin.eof()) {
+ e= new DataEntry;
+ fin>>e->rank;
+ fin>>e->qid;
+ e->feature.resize(fsize);
+ for (int i=0;i<fsize;++i) {
+ fin>>e->feature(i);
+ }
+ out.addEntry(e);
+ }
+ eof=true;
return 0;
}
- int getDataSet(DataSet &D) {return 0;}
- int getLabel(Labels &l) {return 0;}
- virtual int open(){eof=true;return 0;};
- virtual int close(){return 0;};
+ virtual int open(){fin.open(fname); eof=false;return 0;};
+ virtual int close(){fin.close();return 0;};
};
#endif \ No newline at end of file