summaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-04-10 20:39:00 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-04-10 20:39:00 +0800
commit705f3731f4c49a75e2824d16622ff853634335c7 (patch)
tree8c6a171615f27d0cb25484f72ccf1f84391eb9c3 /model
parent01b523c7ce4eb5e692b0dcbec63efac0e8d1e2c7 (diff)
downloadranksvm-705f3731f4c49a75e2824d16622ff853634335c7.tar.gz
ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.tar.bz2
ranksvm-705f3731f4c49a75e2824d16622ff853634335c7.zip
structuring input
Diffstat (limited to 'model')
-rw-r--r--model/ranksvm.h5
-rw-r--r--model/ranksvmtn.cpp13
-rw-r--r--model/ranksvmtn.h4
3 files changed, 12 insertions, 10 deletions
diff --git a/model/ranksvm.h b/model/ranksvm.h
index b4ec7ce..e82b6be 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -3,6 +3,7 @@
#include<Eigen/Dense>
#include<string>
+#include<list>
#include"../tools/dataProvider.h"
#include "../tools/easylogging++.h"
@@ -24,8 +25,8 @@ protected:
SVMModel model;
int fsize;
public:
- virtual int train(DataSet &D, Labels &label)=0;
- virtual int predict(DataSet &D, Labels &res)=0;
+ virtual int train(DataList &D)=0;
+ virtual int predict(DataList &D,std::list<double> &res)=0;
// TODO Not sure how to construct this
// Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.
int saveModel(const std::string fname);
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 3c9808b..821d231 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -1,5 +1,6 @@
#include "ranksvmtn.h"
#include<iostream>
+#include<list>
#include"../tools/matrixIO.h"
using namespace std;
@@ -78,8 +79,8 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
return 0;
}
-int RSVMTN::train(DataSet &D, Labels &label){
- int iter = 0;
+int RSVMTN::train(DataList &D){
+ /*int iter = 0;
double C=1;
MatrixXd A;
@@ -121,13 +122,13 @@ int RSVMTN::train(DataSet &D, Labels &label){
// When dec is small enough
if (-step.dot(grad) < prec * obj)
break;
- }
+ }*/
return 0;
};
-int RSVMTN::predict(DataSet &D, Labels &res){
+int RSVMTN::predict(DataList &D, list<double> &res){
//TODO define A
- MatrixXd A;
- res = A*(D * model.weight);
+ for (list<DataEntry*>::iterator i=D.getData().begin(), end=D.getData().end();i!=end;++i)
+ res.push_back(((*i)->feature).dot(model.weight));
return 0;
}; \ No newline at end of file
diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h
index 703fee4..6ed6ad7 100644
--- a/model/ranksvmtn.h
+++ b/model/ranksvmtn.h
@@ -12,8 +12,8 @@ public:
{
return "TN";
};
- virtual int train(DataSet &D, Labels &label);
- virtual int predict(DataSet &D, Labels &res);
+ virtual int train(DataList &D);
+ virtual int predict(DataList &D,std::list<double> &res);
};
int cg_solve(const Eigen::MatrixXd &A, const Eigen::VectorXd &b, Eigen::VectorXd &x);