diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-04-09 00:00:18 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-09 00:00:18 +0800 |
commit | 8b373d39b893c786197876e8320f3508555c80c3 (patch) | |
tree | 0901089944519d432215cf0549791eacdab6c995 | |
parent | c02a3601c42bc4655d98c0982e6a72c365feafb0 (diff) | |
download | ranksvm-8b373d39b893c786197876e8320f3508555c80c3.tar.gz ranksvm-8b373d39b893c786197876e8320f3508555c80c3.tar.bz2 ranksvm-8b373d39b893c786197876e8320f3508555c80c3.zip |
svm main construction complete
-rw-r--r-- | model/ranksvmtn.cpp | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 193a7d4..559723c 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -13,7 +13,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x) double alpha,beta,r_1,r_2; int step=0; VectorXd q; - VectorXd res = b - A*x; + VectorXd res = b - A*x; // TODO Hessian product VectorXd p = res; while (1) { @@ -28,7 +28,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x) p = res + p * beta; } - q = A*p; + q = A*p; // TODO Hessian product alpha = r_1/p.dot(q); x=x+p*alpha; res=res-q*alpha; @@ -80,11 +80,12 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect int RSVMTN::train(DataSet &D, Labels &label){ int iter = 0; + double C=1; MatrixXd A; // TODO Undefined - int n=D.rows(); + long n=D.rows(); LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n; MatrixXd sv=MatrixXd::Identity(n, n); VectorXd grad(fsize); @@ -104,9 +105,16 @@ int RSVMTN::train(DataSet &D, Labels &label){ } // Generate support vector matrix sv & gradient - objfunc_linear(model.weight,A,1,pred,grad,obj,sv); + objfunc_linear(model.weight,A,C,pred,grad,obj,sv); + step = grad*0; + MatrixXd H = MatrixXd::Identity(grad.cols(),grad.cols()); + // Compute Hessian directly + for (int i=0;i<n;++i) + if (sv(i,i)>0) + H = H + 2*C*A.row(i).transpose()*A.row(i); // Solve + cg_solve(H,grad,step); // do line search line_search(model.weight,D,A,step,pred,t); model.weight=model.weight+step*t; |