summaryrefslogtreecommitdiff
path: root/model/ranksvmtn.cpp
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-04-09 00:00:18 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-04-09 00:00:18 +0800
commit8b373d39b893c786197876e8320f3508555c80c3 (patch)
tree0901089944519d432215cf0549791eacdab6c995 /model/ranksvmtn.cpp
parentc02a3601c42bc4655d98c0982e6a72c365feafb0 (diff)
downloadranksvm-8b373d39b893c786197876e8320f3508555c80c3.tar.gz
ranksvm-8b373d39b893c786197876e8320f3508555c80c3.tar.bz2
ranksvm-8b373d39b893c786197876e8320f3508555c80c3.zip
svm main construction complete
Diffstat (limited to 'model/ranksvmtn.cpp')
-rw-r--r--model/ranksvmtn.cpp16
1 files changed, 12 insertions, 4 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 193a7d4..559723c 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -13,7 +13,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
double alpha,beta,r_1,r_2;
int step=0;
VectorXd q;
- VectorXd res = b - A*x;
+ VectorXd res = b - A*x; // TODO Hessian product
VectorXd p = res;
while (1)
{
@@ -28,7 +28,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
p = res + p * beta;
}
- q = A*p;
+ q = A*p; // TODO Hessian product
alpha = r_1/p.dot(q);
x=x+p*alpha;
res=res-q*alpha;
@@ -80,11 +80,12 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
int RSVMTN::train(DataSet &D, Labels &label){
int iter = 0;
+ double C=1;
MatrixXd A;
// TODO Undefined
- int n=D.rows();
+ long n=D.rows();
LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n;
MatrixXd sv=MatrixXd::Identity(n, n);
VectorXd grad(fsize);
@@ -104,9 +105,16 @@ int RSVMTN::train(DataSet &D, Labels &label){
}
// Generate support vector matrix sv & gradient
- objfunc_linear(model.weight,A,1,pred,grad,obj,sv);
+ objfunc_linear(model.weight,A,C,pred,grad,obj,sv);
+ step = grad*0;
+ MatrixXd H = MatrixXd::Identity(grad.cols(),grad.cols());
+ // Compute Hessian directly
+ for (int i=0;i<n;++i)
+ if (sv(i,i)>0)
+ H = H + 2*C*A.row(i).transpose()*A.row(i);
// Solve
+ cg_solve(H,grad,step);
// do line search
line_search(model.weight,D,A,step,pred,t);
model.weight=model.weight+step*t;