From 85a4962556b67d1cc0668ecb2fbb03b3b4dd6e7e Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Sat, 11 Apr 2015 17:46:19 +0800 Subject: completed & tested, train & predict --- model/ranksvmtn.cpp | 117 ++++++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 59 deletions(-) (limited to 'model/ranksvmtn.cpp') diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 776d4db..959ea7d 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -7,7 +7,8 @@ using namespace std; using namespace Eigen; const int maxiter = 10; -const double prec=1e-3; +const double prec=1e-4; +const double C=1; int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x) { @@ -20,9 +21,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x) { // Non preconditioned version r_1 = res.dot(res); - cout<0?pred(i):0; obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2; - grad = w - (((pred*C).transpose()*A)*w).transpose(); - for (int i=0;i0) - sv(i,i)=1; - else - sv(i,i)=0; + grad = w - (((pred*C).transpose()*A)*D).transpose(); return 0; } @@ -63,36 +55,40 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect double g,h; t = 0; VectorXd Xd=A*(D*step); + VectorXd pred2; while (1) { - pred = pred - t*Xd; + pred2 = pred - t*Xd; g=wd+t*dd; h=dd; - for (int i=0;i0) { - g += pred(i)*Xd(i); - h += Xd(i)*Xd(i); + for (int i=0;i0) { + g -= C*pred2(i)*Xd(i); + h += C*Xd(i)*Xd(i); } + g=g+1e-12; + h=h+1e-12; + t=t-g/h; + cout<0) - H = H + 2*C*A.row(i).transpose()*A.row(i); + if (pred(i)>0) { + VectorXd v = A.row(i)*D; + H = H + C * (v * v.transpose()); + } // Solve + //cout<0) + ++sv; // When dec is small enough - if (-step.dot(grad) < prec * obj) + LOG(INFO)<<"Iter: "<::iterator iter,st,nx; - for (iter= D.getData().begin();ifeature(j); - nx=st=iter= D.getData().begin(); - ++nx; + int i,j; + LOG(INFO)<<"Processing input"; + for (i=0;ifeature(j); + } int cnt=0; - while (iter!=D.getData().end()) + i=j=0; + while (iqid!=(*nx)->qid) + if ((i+1 == D.getSize())|| D.getData()[i]->qid!=D.getData()[i+1]->qid) { - list::iterator high,low=iter; - for (high=st;((*high)->rank)>0;++high) - for (low=iter;((*low)->rank)<0;--low) - ++cnt; - st = nx; + int high=j; + while (D.getData()[high]->rank>0) + ++high; + cnt += (high-j)*(i-high+1); + j = i+1; } - ++iter; + ++i; } A.resize(cnt,D.getSize()); - nx=st=iter= D.getData().begin(); - ++nx; cnt=i=j=0; - while (iter!=D.getData().end()) + while (iqid!=(*nx)->qid) + if ((i+1 == D.getSize())|| D.getData()[i]->qid!=D.getData()[i+1]->qid) { int v1=j,v2; - list::iterator high,low=iter; - for (high=st;((*high)->rank)>0;++high,++v1) - for (low=iter,v2=i;((*low)->rank)<0;--low,--v2) { + for (v1=j;(D.getData()[v1]->rank)>0;++v1) + for (v2=i;(D.getData()[v2]->rank)<0;--v2) { A(cnt,v1) = 1; A(cnt,v2) = -1; ++cnt; } - st = nx; - j=i+1; + j = i+1; } ++i; - ++iter; } train_orig(fsize,Data,A,model.weight); return 0; }; -int RSVMTN::predict(DataList &D, list &res){ +int RSVMTN::predict(DataList &D, vector &res){ //TODO define A - for (list::iterator i=D.getData().begin(), end=D.getData().end();i!=end;++i) - res.push_back(((*i)->feature).dot(model.weight)); + res.clear(); + for (int i=0;ifeature).dot(model.weight)); return 0; }; \ No newline at end of file -- cgit v1.2.3-70-g09d2