summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-04-11 09:16:57 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-04-11 09:16:57 +0800
commit5d57accc3e1bc2b89c1e7753f7dbb40f3c8f575a (patch)
tree86f1256e7d804f508557ebd33d7e3216e4637de3
parent705f3731f4c49a75e2824d16622ff853634335c7 (diff)
downloadranksvm-5d57accc3e1bc2b89c1e7753f7dbb40f3c8f575a.tar.gz
ranksvm-5d57accc3e1bc2b89c1e7753f7dbb40f3c8f575a.tar.bz2
ranksvm-5d57accc3e1bc2b89c1e7753f7dbb40f3c8f575a.zip
input processing complete
-rw-r--r--model/ranksvmtn.cpp65
1 files changed, 55 insertions, 10 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 821d231..776d4db 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -79,13 +79,10 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
return 0;
}
-int RSVMTN::train(DataList &D){
- /*int iter = 0;
+int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
+ int iter = 0;
double C=1;
- MatrixXd A;
- // TODO Undefined
-
long n=D.rows();
LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n;
MatrixXd sv=MatrixXd::Identity(n, n);
@@ -94,7 +91,7 @@ int RSVMTN::train(DataList &D){
VectorXd pred(n);
double obj,t;
- pred=VectorXd::Ones(n) - (A*(D*model.weight));
+ pred=VectorXd::Ones(n) - (A*(D*weight));
while (true)
{
@@ -106,7 +103,7 @@ int RSVMTN::train(DataList &D){
}
// Generate support vector matrix sv & gradient
- objfunc_linear(model.weight,A,C,pred,grad,obj,sv);
+ objfunc_linear(weight,A,C,pred,grad,obj,sv);
step = grad*0;
MatrixXd H = MatrixXd::Identity(grad.cols(),grad.cols());
@@ -117,12 +114,60 @@ int RSVMTN::train(DataList &D){
// Solve
cg_solve(H,grad,step);
// do line search
- line_search(model.weight,D,A,step,pred,t);
- model.weight=model.weight+step*t;
+ line_search(weight,D,A,step,pred,t);
+ weight=weight+step*t;
// When dec is small enough
if (-step.dot(grad) < prec * obj)
break;
- }*/
+ }
+ return 0;
+}
+
+int RSVMTN::train(DataList &D){
+ MatrixXd Data(D.getSize(),D.getfSize()),A;
+ int i=0,j=0;
+ list<DataEntry*>::iterator iter,st,nx;
+ for (iter= D.getData().begin();i<D.getSize();++i,++iter)
+ for (j=0;j<D.getfSize();++j)
+ Data(i,j)=(*iter)->feature(j);
+ nx=st=iter= D.getData().begin();
+ ++nx;
+ int cnt=0;
+ while (iter!=D.getData().end())
+ {
+ if ((nx == D.getData().end())||(*iter)->qid!=(*nx)->qid)
+ {
+ list<DataEntry*>::iterator high,low=iter;
+ for (high=st;((*high)->rank)>0;++high)
+ for (low=iter;((*low)->rank)<0;--low)
+ ++cnt;
+ st = nx;
+ }
+ ++iter;
+ }
+ A.resize(cnt,D.getSize());
+ nx=st=iter= D.getData().begin();
+ ++nx;
+ cnt=i=j=0;
+ while (iter!=D.getData().end())
+ {
+ if ((nx == D.getData().end())||(*iter)->qid!=(*nx)->qid)
+ {
+ int v1=j,v2;
+ list<DataEntry*>::iterator high,low=iter;
+ for (high=st;((*high)->rank)>0;++high,++v1)
+ for (low=iter,v2=i;((*low)->rank)<0;--low,--v2) {
+ A(cnt,v1) = 1;
+ A(cnt,v2) = -1;
+ ++cnt;
+ }
+ st = nx;
+ j=i+1;
+ }
+ ++i;
+ ++iter;
+ }
+ train_orig(fsize,Data,A,model.weight);
return 0;
};