| author | Joe Zhao <ztuowen@gmail.com> | 2015-03-15 23:43:26 +0800 |
| --- | --- | --- |
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-03-15 23:43:26 +0800 |
| commit | a698a1e2cceb6947e244a0de01424894dc37028a (patch) | |
| tree | 80189049afcfdf103b58d3013e8bad9277df2539 | |
| parent | 68b7998f12e7e3c9b257a9d2a4c14e8d32a048b9 (diff) | |
inp
-rw-r--r-- | model/ranksvmtn.cpp | 12 |
1 file changed, 10 insertions, 2 deletions
```diff
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 59980b4..6a7057d 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -7,9 +7,16 @@
 const int maxiter = 10;
 const double prec=1e-3;
 // Calculate objfunc gradient & support vectors
-int objfunc_linear(const VectorXd &w,const double C,const VectorXd &pred, double &obj,MatrixXd &sv)
+int objfunc_linear(const VectorXd &w,const double C,const VectorXd &pred,const VectorXd &grad, double &obj,MatrixXd &sv)
 {
-
+    pred = pred.cwiseMax(Matrix::Zero(pred.rows(),pred.cols()));
+    obj = (pred.cwiseProduct(pred)*(C/2)) + w.transpose()*w/2;
+    grad = w - (((pred*C).transpose()*A)*w).transpose();
+    for (int i=0;i<pred.cols();++i)
+        if (pred(i)>0)
+            sv(i,i)=1;
+        else
+            sv(i,i)=0;
 }
 
 // line search
@@ -43,6 +50,7 @@ int RSVMTN::train(DataSet &D, Labels &label){
     }
     // Generate support vector matrix sv & gradient
+    objfunc_linear(D,1,pred,grad,obj,sv);
     model.weight=model.weight+step*t;
     // When dec is small enough
     if (-step.dot(grad) < prec * obj)
```
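For orientation, the first hunk fills in `objfunc_linear`: clamp the pairwise slack vector, evaluate the squared-hinge objective, form the gradient, and mark the active pairs on the diagonal of `sv`. As committed, the body assigns to the `const` parameters `pred` and `grad` and references an `A` and a `Matrix` type that are not visible in this hunk, so it would not compile as shown. Below is a minimal, self-contained sketch of the same computation, assuming Eigen and an explicitly passed pairwise matrix `A`; the non-const `pred`/`grad` parameters and the gradient form `w - C * A' * pred` follow the standard L2-loss RankSVM derivation and are my assumptions, not the repository's code.

```cpp
// Standalone sketch (not the repository's code): L2-loss RankSVM objective,
// gradient, and support-vector indicator. The matrix A of pairwise
// differences is an assumption; ranksvmtn.cpp may organize this differently.
#include <Eigen/Dense>
#include <iostream>

using Eigen::MatrixXd;
using Eigen::VectorXd;

// obj  = 0.5*w'w + (C/2) * sum_i max(0, pred_i)^2
// grad = w - C * A' * max(0, pred)
// sv   = diagonal 0/1 matrix marking pairs with positive slack
int objfunc_linear(const VectorXd &w, const MatrixXd &A, double C,
                   VectorXd pred, double &obj, VectorXd &grad, MatrixXd &sv)
{
    pred = pred.cwiseMax(VectorXd::Zero(pred.rows()));  // clamp negative slack
    obj  = 0.5 * w.squaredNorm() + 0.5 * C * pred.squaredNorm();
    grad = w - C * (A.transpose() * pred);
    sv.setZero(pred.rows(), pred.rows());
    for (int i = 0; i < pred.rows(); ++i)
        if (pred(i) > 0)
            sv(i, i) = 1;
    return 0;
}

int main()
{
    // Tiny example: 3 pairwise constraints over a 2-dimensional weight vector.
    MatrixXd A(3, 2);
    A <<  1, -1,
          0,  1,
         -1,  1;
    VectorXd w(2);
    w << 0.5, 0.25;
    VectorXd pred = VectorXd::Ones(3) - A * w;  // slack of each pair: 1 - A*w

    double obj;
    VectorXd grad;
    MatrixXd sv;
    objfunc_linear(w, A, 1.0, pred, obj, grad, sv);
    std::cout << "obj = " << obj << "\ngrad =\n" << grad << std::endl;
    return 0;
}
```

The second hunk then calls `objfunc_linear` inside `RSVMTN::train`, so the convergence test `-step.dot(grad) < prec * obj` a few lines later operates on a freshly computed objective and gradient.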