author | Joe Zhao <ztuowen@gmail.com> | 2015-04-08 18:04:41 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-08 18:04:41 +0800 |
commit | c02a3601c42bc4655d98c0982e6a72c365feafb0 (patch) | |
tree | 44fa7ecf19becfe900ad93d54e3be0f58ecd040f | |
parent | 2aed1b11102196f3d839b2801a92a87243355725 (diff) | |
download | ranksvm-c02a3601c42bc4655d98c0982e6a72c365feafb0.tar.gz, ranksvm-c02a3601c42bc4655d98c0982e6a72c365feafb0.tar.bz2, ranksvm-c02a3601c42bc4655d98c0982e6a72c365feafb0.zip | |
linesearch & objective function passed build
-rw-r--r-- | model/ranksvmtn.cpp | 30 |
1 file changed, 25 insertions, 5 deletions
```diff
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index fe29468..193a7d4 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -45,7 +45,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
 int objfunc_linear(const VectorXd &w,const MatrixXd &A,const double C,VectorXd &pred,VectorXd &grad, double &obj,MatrixXd &sv)
 {
     pred = pred.cwiseMax(MatrixXd::Zero(pred.rows(),pred.cols()));
-//    obj = (pred.cwiseProduct(pred)*(C/2)) + w.transpose()*w/2;
+    obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2;
     grad = w - (((pred*C).transpose()*A)*w).transpose();
     for (int i=0;i<pred.cols();++i)
         if (pred(i)>0)
@@ -55,9 +55,26 @@ int objfunc_linear(const VectorXd &w,const MatrixXd &A,const double C,VectorXd &
     return 0;
 }
 
-// line search
-int line_search(const VectorXd &w,const double C,const VectorXd &step,VectorXd &pred,double &t)
+// line search using newton method
+int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const VectorXd &step,VectorXd &pred,double &t)
 {
+    double wd=w.dot(step),dd=step.dot(step);
+    double g,h;
+    t = 0;
+    VectorXd Xd=A*(D*step);
+    while (1)
+    {
+        pred = pred - t*Xd;
+        g=wd+t*dd;
+        h=dd;
+        for (int i=0;i<pred.cols();++i)
+            if (pred(i)>0) {
+                g += pred(i)*Xd(i);
+                h += Xd(i)*Xd(i);
+            }
+        if (g*g/h<1e-10)
+            break;
+    }
     return 0;
 }
 
@@ -87,13 +104,16 @@ int RSVMTN::train(DataSet &D, Labels &label){
         }
         // Generate support vector matrix sv & gradient
-        objfunc_linear(D,A,1,pred,grad,obj,sv);
+        objfunc_linear(model.weight,A,1,pred,grad,obj,sv);
+
+        // Solve
+
+        // do line search
+        line_search(model.weight,D,A,step,pred,t);
         model.weight=model.weight+step*t;
         // When dec is small enough
         if (-step.dot(grad) < prec * obj)
             break;
     }
-
     return 0;
 };
```
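For reference, the new `obj` line corresponds to the standard L2-regularized squared-hinge (L2-loss) objective. A minimal sketch in LaTeX, assuming `pred` holds the clamped slacks produced by the preceding `cwiseMax` call:

```latex
% obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2 corresponds to:
\[
  \mathrm{obj}(w) \;=\; \frac{C}{2}\sum_{i}\max(0,\,s_i)^{2} \;+\; \frac{1}{2}\,w^{\top}w
\]
% where s_i is the slack of the i-th ranking pair before clamping.
```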
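The comment added in the diff describes `line_search` as a Newton-method search over the step length `t`. Below is a standalone sketch of that kind of one-dimensional Newton search for the same squared-hinge objective; it is not the repository's code. The names `newton_line_search`, `s`, `q`, `eps`, and `max_iter` are illustrative, and the explicit update `t -= g/h` is an assumption about the intended iteration.

```cpp
// Sketch: Newton line search for
//   phi(t) = 1/2*||w + t*d||^2 + C/2 * sum_i max(0, s_i - t*q_i)^2
// where s holds the current slacks and q = X*d is how the step changes them.
#include <Eigen/Dense>
#include <iostream>

using Eigen::VectorXd;

double newton_line_search(const VectorXd &w, const VectorXd &d,
                          const VectorXd &s, const VectorXd &q,
                          double C, double eps = 1e-10, int max_iter = 50)
{
    double t = 0.0;
    const double wd = w.dot(d), dd = d.dot(d);
    for (int iter = 0; iter < max_iter; ++iter)
    {
        // First and second derivatives of phi at the current t.
        double g = wd + t * dd;   // from the 1/2*||w + t*d||^2 term
        double h = dd;
        for (int i = 0; i < s.size(); ++i)
        {
            const double si = s(i) - t * q(i);   // slack after the step
            if (si > 0) {                        // only violated pairs contribute
                g -= C * si * q(i);
                h += C * q(i) * q(i);
            }
        }
        if (g * g / h < eps)                     // Newton decrement small enough
            break;
        t -= g / h;                              // Newton update
    }
    return t;
}

int main()
{
    // Tiny smoke test: one weight vector, one direction, two slacks.
    VectorXd w(2), d(2), s(2), q(2);
    w << 1.0, -0.5;
    d << -0.2, 0.1;
    s << 0.3, -0.1;
    q << 0.5, 0.2;
    std::cout << "t = " << newton_line_search(w, d, s, q, 1.0) << "\n";
    return 0;
}
```

The stopping rule `g*g/h < eps` mirrors the `g*g/h<1e-10` test in the committed loop; the surrounding derivative bookkeeping is only a sketch of the general technique.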