summaryrefslogtreecommitdiff
path: root/model/ranksvmtn.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'model/ranksvmtn.cpp')
-rw-r--r--model/ranksvmtn.cpp12
1 files changed, 10 insertions, 2 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 59980b4..6a7057d 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -7,9 +7,16 @@ const int maxiter = 10;
const double prec=1e-3;
// Calculate objfunc gradient & support vectors
-int objfunc_linear(const VectorXd &w,const double C,const VectorXd &pred, double &obj,MatrixXd &sv)
+int objfunc_linear(const VectorXd &w,const double C,const VectorXd &pred,const VectorXd &grad, double &obj,MatrixXd &sv)
{
-
+ pred = pred.cwiseMax(Matrix::Zero(pred.rows(),pred.cols()));
+ obj = (pred.cwiseProduct(pred)*(C/2)) + w.transpose()*w/2;
+ grad = w - (((pred*C).transpose()*A)*w).transpose();
+ for (int i=0;i<pred.cols();++i)
+ if (pred(i)>0)
+ sv(i,i)=1;
+ else
+ sv(i,i)=0;
}
// line search
@@ -43,6 +50,7 @@ int RSVMTN::train(DataSet &D, Labels &label){
}
// Generate support vector matrix sv & gradient
+ objfunc_linear(D,1,pred,grad,obj,sv);
model.weight=model.weight+step*t;
// When dec is small enough
if (-step.dot(grad) < prec * obj)