From 705f3731f4c49a75e2824d16622ff853634335c7 Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Fri, 10 Apr 2015 20:39:00 +0800 Subject: structuring input --- model/ranksvm.h | 5 +++-- model/ranksvmtn.cpp | 13 +++++++------ model/ranksvmtn.h | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) (limited to 'model') diff --git a/model/ranksvm.h b/model/ranksvm.h index b4ec7ce..e82b6be 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -3,6 +3,7 @@ #include #include +#include #include"../tools/dataProvider.h" #include "../tools/easylogging++.h" @@ -24,8 +25,8 @@ protected: SVMModel model; int fsize; public: - virtual int train(DataSet &D, Labels &label)=0; - virtual int predict(DataSet &D, Labels &res)=0; + virtual int train(DataList &D)=0; + virtual int predict(DataList &D,std::list &res)=0; // TODO Not sure how to construct this // Possible solution: generate a nxn matrix each row contains the sorted list of ranker result. int saveModel(const std::string fname); diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 3c9808b..821d231 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -1,5 +1,6 @@ #include "ranksvmtn.h" #include +#include #include"../tools/matrixIO.h" using namespace std; @@ -78,8 +79,8 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect return 0; } -int RSVMTN::train(DataSet &D, Labels &label){ - int iter = 0; +int RSVMTN::train(DataList &D){ + /*int iter = 0; double C=1; MatrixXd A; @@ -121,13 +122,13 @@ int RSVMTN::train(DataSet &D, Labels &label){ // When dec is small enough if (-step.dot(grad) < prec * obj) break; - } + }*/ return 0; }; -int RSVMTN::predict(DataSet &D, Labels &res){ +int RSVMTN::predict(DataList &D, list &res){ //TODO define A - MatrixXd A; - res = A*(D * model.weight); + for (list::iterator i=D.getData().begin(), end=D.getData().end();i!=end;++i) + res.push_back(((*i)->feature).dot(model.weight)); return 0; }; \ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 703fee4..6ed6ad7 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -12,8 +12,8 @@ public: { return "TN"; }; - virtual int train(DataSet &D, Labels &label); - virtual int predict(DataSet &D, Labels &res); + virtual int train(DataList &D); + virtual int predict(DataList &D,std::list &res); }; int cg_solve(const Eigen::MatrixXd &A, const Eigen::VectorXd &b, Eigen::VectorXd &x); -- cgit v1.2.3-70-g09d2