summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--model/ranksvmtn.cpp30
1 files changed, 25 insertions, 5 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index fe29468..193a7d4 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -45,7 +45,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
int objfunc_linear(const VectorXd &w,const MatrixXd &A,const double C,VectorXd &pred,VectorXd &grad, double &obj,MatrixXd &sv)
{
pred = pred.cwiseMax(MatrixXd::Zero(pred.rows(),pred.cols()));
-// obj = (pred.cwiseProduct(pred)*(C/2)) + w.transpose()*w/2;
+ obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2;
grad = w - (((pred*C).transpose()*A)*w).transpose();
for (int i=0;i<pred.cols();++i)
if (pred(i)>0)
@@ -55,9 +55,26 @@ int objfunc_linear(const VectorXd &w,const MatrixXd &A,const double C,VectorXd &
return 0;
}
-// line search
-int line_search(const VectorXd &w,const double C,const VectorXd &step,VectorXd &pred,double &t)
+// line search using newton method
+int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const VectorXd &step,VectorXd &pred,double &t)
{
+ double wd=w.dot(step),dd=step.dot(step);
+ double g,h;
+ t = 0;
+ VectorXd Xd=A*(D*step);
+ while (1)
+ {
+ pred = pred - t*Xd;
+ g=wd+t*dd;
+ h=dd;
+ for (int i=0;i<pred.cols();++i)
+ if (pred(i)>0) {
+ g += pred(i)*Xd(i);
+ h += Xd(i)*Xd(i);
+ }
+ if (g*g/h<1e-10)
+ break;
+ }
return 0;
}
@@ -87,13 +104,16 @@ int RSVMTN::train(DataSet &D, Labels &label){
}
// Generate support vector matrix sv & gradient
- objfunc_linear(D,A,1,pred,grad,obj,sv);
+ objfunc_linear(model.weight,A,1,pred,grad,obj,sv);
+
+ // Solve
+ // do line search
+ line_search(model.weight,D,A,step,pred,t);
model.weight=model.weight+step*t;
// When dec is small enough
if (-step.dot(grad) < prec * obj)
break;
}
-
return 0;
};