diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-04-08 18:04:41 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-08 18:04:41 +0800 | 
| commit | c02a3601c42bc4655d98c0982e6a72c365feafb0 (patch) | |
| tree | 44fa7ecf19becfe900ad93d54e3be0f58ecd040f /model | |
| parent | 2aed1b11102196f3d839b2801a92a87243355725 (diff) | |
| download | ranksvm-c02a3601c42bc4655d98c0982e6a72c365feafb0.tar.gz ranksvm-c02a3601c42bc4655d98c0982e6a72c365feafb0.tar.bz2 ranksvm-c02a3601c42bc4655d98c0982e6a72c365feafb0.zip  | |
linesearch&objective function passed build
Diffstat (limited to 'model')
| -rw-r--r-- | model/ranksvmtn.cpp | 30 | 
1 files changed, 25 insertions, 5 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index fe29468..193a7d4 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -45,7 +45,7 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)  int objfunc_linear(const VectorXd &w,const MatrixXd &A,const double C,VectorXd &pred,VectorXd &grad, double &obj,MatrixXd &sv)  {      pred = pred.cwiseMax(MatrixXd::Zero(pred.rows(),pred.cols())); -//    obj = (pred.cwiseProduct(pred)*(C/2)) + w.transpose()*w/2; +    obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2;      grad = w - (((pred*C).transpose()*A)*w).transpose();      for (int i=0;i<pred.cols();++i)          if (pred(i)>0) @@ -55,9 +55,26 @@ int objfunc_linear(const VectorXd &w,const MatrixXd &A,const double C,VectorXd &      return 0;  } -// line search -int line_search(const VectorXd &w,const double C,const VectorXd &step,VectorXd &pred,double &t) +// line search using newton method +int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const VectorXd &step,VectorXd &pred,double &t)  { +    double wd=w.dot(step),dd=step.dot(step); +    double g,h; +    t = 0; +    VectorXd Xd=A*(D*step); +    while (1) +    { +        pred = pred - t*Xd; +        g=wd+t*dd; +        h=dd; +        for (int i=0;i<pred.cols();++i) +            if (pred(i)>0) { +                g += pred(i)*Xd(i); +                h += Xd(i)*Xd(i); +            } +        if (g*g/h<1e-10) +            break; +    }      return 0;  } @@ -87,13 +104,16 @@ int RSVMTN::train(DataSet &D, Labels &label){          }          // Generate support vector matrix sv & gradient -        objfunc_linear(D,A,1,pred,grad,obj,sv); +        objfunc_linear(model.weight,A,1,pred,grad,obj,sv); + +        // Solve +        // do line search +        line_search(model.weight,D,A,step,pred,t);          model.weight=model.weight+step*t;          // When dec is small enough          if (-step.dot(grad) < prec * obj)              break;      } -      return 0;  };  | 
