diff options
Diffstat (limited to 'model/ranksvmtn.cpp')
-rw-r--r-- | model/ranksvmtn.cpp | 65 |
1 files changed, 55 insertions, 10 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 821d231..776d4db 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -79,13 +79,10 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect return 0; } -int RSVMTN::train(DataList &D){ - /*int iter = 0; +int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){ + int iter = 0; double C=1; - MatrixXd A; - // TODO Undefined - long n=D.rows(); LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n; MatrixXd sv=MatrixXd::Identity(n, n); @@ -94,7 +91,7 @@ int RSVMTN::train(DataList &D){ VectorXd pred(n); double obj,t; - pred=VectorXd::Ones(n) - (A*(D*model.weight)); + pred=VectorXd::Ones(n) - (A*(D*weight)); while (true) { @@ -106,7 +103,7 @@ int RSVMTN::train(DataList &D){ } // Generate support vector matrix sv & gradient - objfunc_linear(model.weight,A,C,pred,grad,obj,sv); + objfunc_linear(weight,A,C,pred,grad,obj,sv); step = grad*0; MatrixXd H = MatrixXd::Identity(grad.cols(),grad.cols()); @@ -117,12 +114,60 @@ int RSVMTN::train(DataList &D){ // Solve cg_solve(H,grad,step); // do line search - line_search(model.weight,D,A,step,pred,t); - model.weight=model.weight+step*t; + line_search(weight,D,A,step,pred,t); + weight=weight+step*t; // When dec is small enough if (-step.dot(grad) < prec * obj) break; - }*/ + } + return 0; +} + +int RSVMTN::train(DataList &D){ + MatrixXd Data(D.getSize(),D.getfSize()),A; + int i=0,j=0; + list<DataEntry*>::iterator iter,st,nx; + for (iter= D.getData().begin();i<D.getSize();++i,++iter) + for (j=0;j<D.getfSize();++j) + Data(i,j)=(*iter)->feature(j); + nx=st=iter= D.getData().begin(); + ++nx; + int cnt=0; + while (iter!=D.getData().end()) + { + if ((nx == D.getData().end())||(*iter)->qid!=(*nx)->qid) + { + list<DataEntry*>::iterator high,low=iter; + for (high=st;((*high)->rank)>0;++high) + for (low=iter;((*low)->rank)<0;--low) + ++cnt; + st = nx; + } + ++iter; + } + A.resize(cnt,D.getSize()); + nx=st=iter= D.getData().begin(); + ++nx; + cnt=i=j=0; + while (iter!=D.getData().end()) + { + if ((nx == D.getData().end())||(*iter)->qid!=(*nx)->qid) + { + int v1=j,v2; + list<DataEntry*>::iterator high,low=iter; + for (high=st;((*high)->rank)>0;++high,++v1) + for (low=iter,v2=i;((*low)->rank)<0;--low,--v2) { + A(cnt,v1) = 1; + A(cnt,v2) = -1; + ++cnt; + } + st = nx; + j=i+1; + } + ++i; + ++iter; + } + train_orig(fsize,Data,A,model.weight); return 0; }; |