author | Joe Zhao <ztuowen@gmail.com> | 2015-04-11 19:32:30 +0800
---|---|---
committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-11 19:32:30 +0800
commit | ddfbe260e76c4c5cf5e601e21d39c5f5cc67b2c6 (patch) |
tree | 7ab82ebfe36f263f6b4286ed312cabc2ce759ba6 /model |
parent | 85a4962556b67d1cc0668ecb2fbb03b3b4dd6e7e (diff) |
A substituted: replace the dense preference-pair matrix A with index vectors A1/A2
Diffstat (limited to 'model')
-rw-r--r-- | model/ranksvmtn.cpp | 47
1 file changed, 28 insertions(+), 19 deletions(-)
```diff
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 959ea7d..105c3fe 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -39,22 +39,30 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
 }
 
 // Calculate objfunc gradient & support vectors
-int objfunc_linear(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const double C,VectorXd &pred,VectorXd &grad, double &obj)
+int objfunc_linear(const VectorXd &w,const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const double C,VectorXd &pred,VectorXd &grad, double &obj)
 {
     for (int i=0;i<pred.rows();++i)
         pred(i)=pred(i)>0?pred(i):0;
     obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2;
-    grad = w - (((pred*C).transpose()*A)*D).transpose();
+    VectorXd pA = VectorXd::Zero(D.rows());
+    for (int i=0;i<A1.size();++i) {
+        pA(A1[i]) = pA(A1[i]) + pred(i);
+        pA(A2[i]) = pA(A2[i]) - pred(i);
+    }
+    grad = w - (pA.transpose()*D).transpose();
     return 0;
 }
 
 // line search using newton method
-int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const VectorXd &step,VectorXd &pred,double &t)
+int line_search(const VectorXd &w,const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &step,VectorXd &pred,double &t)
 {
     double wd=w.dot(step),dd=step.dot(step);
+    VectorXd Dd = D*step;
+    VectorXd Xd = VectorXd::Zero(A1.size());
+    for (int i=0;i<A1.size();++i)
+        Xd(i) = Dd(A1[i])-Dd(A2[i]);
     double g,h;
     t = 0;
-    VectorXd Xd=A*(D*step);
     VectorXd pred2;
     while (1)
     {
@@ -69,8 +77,6 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
         g=g+1e-12;
         h=h+1e-12;
         t=t-g/h;
-        cout<<g<<":"<<h<<endl;
-        cin.get();
         if (g*g/h<1e-10)
             break;
     }
@@ -78,17 +84,20 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
     return 0;
 }
 
-int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
+int train_orig(int fsize, MatrixXd &D,vector<int> &A1,vector<int> &A2,VectorXd &weight){
     int iter = 0;
-    long n=A.rows();
-    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Relation size:" << A.rows();
+    long n=A1.size();
+    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Relation size:" << A1.size();
     VectorXd grad(fsize);
     VectorXd step(fsize);
     VectorXd pred(n);
     double obj,t;
-    pred=VectorXd::Ones(n) - (A*(D*weight));
+    VectorXd dw = D*weight;
+    pred=VectorXd::Zero(n);
+    for (int i=0;i<n;++i)
+        pred(i) = 1 - dw(A1[i])+dw(A2[i]);
     while (true)
     {
         iter+=1;
@@ -99,20 +108,21 @@ int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
         }
         // Generate support vector matrix sv & gradient
-        objfunc_linear(weight,D,A,C,pred,grad,obj);
+        objfunc_linear(weight,D,A1,A2,C,pred,grad,obj);
         step = grad*0;
         MatrixXd H = MatrixXd::Identity(grad.rows(),grad.rows());
         // Compute Hessian directly
         for (int i=0;i<n;++i)
             if (pred(i)>0) {
-                VectorXd v = A.row(i)*D;
+                VectorXd v = D.row(A1[i])-D.row(A2[i]);
                 H = H + C * (v * v.transpose());
             }
         // Solve
         //cout<<obj<<endl;
         cg_solve(H,grad,step);
         // do line search
-        line_search(weight,D,A,step,pred,t);
+
+        line_search(weight,D,A1,A2,step,pred,t);
         weight=weight+step*t;
         int sv=0;
         for (int i=0;i<n;++i)
@@ -127,7 +137,8 @@ int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
 }
 
 int RSVMTN::train(DataList &D){
-    MatrixXd Data(D.getSize(),D.getfSize()),A;
+    MatrixXd Data(D.getSize(),D.getfSize());
+    vector<int> A1,A2;
     int i,j;
     LOG(INFO)<<"Processing input";
     for (i=0;i<D.getSize();++i) {
@@ -148,7 +159,6 @@ int RSVMTN::train(DataList &D){
         }
         ++i;
     }
-    A.resize(cnt,D.getSize());
     cnt=i=j=0;
     while (i<D.getSize())
     {
@@ -157,20 +167,19 @@ int RSVMTN::train(DataList &D){
         int v1=j,v2;
         for (v1=j;(D.getData()[v1]->rank)>0;++v1)
             for (v2=i;(D.getData()[v2]->rank)<0;--v2)
             {
-                A(cnt,v1) = 1;
-                A(cnt,v2) = -1;
+                A1.push_back(v1);
+                A2.push_back(v2);
                 ++cnt;
             }
             j = i+1;
         }
         ++i;
     }
-    train_orig(fsize,Data,A,model.weight);
+    train_orig(fsize,Data,A1,A2,model.weight);
     return 0;
 };
 
 int RSVMTN::predict(DataList &D, vector<double> &res){
-    //TODO define A
     res.clear();
     for (int i=0;i<D.getSize();++i)
         res.push_back(((D.getData()[i])->feature).dot(model.weight));
```
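The substitution works because row i of the old pair matrix `A` held `+1` in column `A1[i]` (the higher-ranked sample) and `-1` in column `A2[i]`, so every product involving `A` collapses into index lookups: `A*(D*w)` becomes `dw(A1[i])-dw(A2[i])`, `p.transpose()*A*D` becomes a scatter of `±p` into a length-m vector followed by one multiply with `D`, and `A.row(i)*D` becomes `D.row(A1[i])-D.row(A2[i])`. The sketch below is a minimal standalone check of those three identities with Eigen; it is not part of the repository, and the sizes, index pairs, and vector `p` are arbitrary stand-ins:

```cpp
// Standalone sketch (not from this repository): verify that the index-pair
// representation introduced in this commit reproduces the old dense pair
// matrix A. Row i of A had +1 in column A1[i] and -1 in column A2[i];
// sizes and data below are arbitrary.
#include <Eigen/Dense>
#include <iostream>
#include <vector>

using Eigen::MatrixXd;
using Eigen::VectorXd;

int main()
{
    const int n = 4, m = 6, f = 3;             // pairs, samples, features
    std::vector<int> A1 = {0, 1, 2, 3};        // preferred sample of each pair
    std::vector<int> A2 = {4, 5, 4, 5};        // less-preferred sample
    MatrixXd D = MatrixXd::Random(m, f);       // feature matrix, rows = samples
    VectorXd w = VectorXd::Random(f);          // weight vector
    VectorXd p = VectorXd::Random(n);          // stand-in for the pred vector

    // The dense pair matrix as it existed before this commit.
    MatrixXd A = MatrixXd::Zero(n, m);
    for (int i = 0; i < n; ++i) { A(i, A1[i]) = 1; A(i, A2[i]) = -1; }

    // 1) pred-style product: A*(D*w) == dw(A1[i]) - dw(A2[i])
    VectorXd dw = D * w;
    VectorXd lhs1 = A * dw, rhs1(n);
    for (int i = 0; i < n; ++i) rhs1(i) = dw(A1[i]) - dw(A2[i]);

    // 2) gradient-style product: (p^T*A)*D == pA^T*D, with +-p scattered into pA
    VectorXd pA = VectorXd::Zero(m);
    for (int i = 0; i < n; ++i) { pA(A1[i]) += p(i); pA(A2[i]) -= p(i); }
    VectorXd lhs2 = (p.transpose() * A * D).transpose();
    VectorXd rhs2 = (pA.transpose() * D).transpose();

    // 3) Hessian-style row: A.row(i)*D == D.row(A1[i]) - D.row(A2[i])
    VectorXd lhs3 = (A.row(0) * D).transpose();
    VectorXd rhs3 = (D.row(A1[0]) - D.row(A2[0])).transpose();

    // All three differences should print as (numerically) zero.
    std::cout << (lhs1 - rhs1).norm() << " "
              << (lhs2 - rhs2).norm() << " "
              << (lhs3 - rhs3).norm() << std::endl;
    return 0;
}
```

Besides dropping the O(pairs × samples) dense storage (the deleted `A.resize(cnt,D.getSize())` above), each rewritten product replaces a dense matrix-vector multiply with O(pairs) gather/scatter work plus a single pass over `D`.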