diff options
| -rw-r--r-- | model/ranksvm.cpp | 6 | ||||
| -rw-r--r-- | model/ranksvm.h | 4 | ||||
| -rw-r--r-- | model/ranksvmtn.cpp | 6 | ||||
| -rw-r--r-- | train.cpp | 27 | 
4 files changed, 26 insertions, 17 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 2f366d6..368e16d 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -15,11 +15,11 @@ int maxiter = 1; // max iteration count  double prec=1e-10; // precision  // conjugate gradient  double cg_prec=1e-10; // precision -int cg_maxiter = 2000; +int cg_maxiter = 100;  int ls_maxiter = 20;  // line search -double line_prec=1e-10; // precision -double line_turb=1e-15; // purturbation +double ls_prec=1e-10; // precision +double ls_turb=1e-15; // purturbation  int RSVM::saveModel(const string fname){ diff --git a/model/ranksvm.h b/model/ranksvm.h index 41de0f4..a17e3c9 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -46,7 +46,7 @@ extern double cg_prec; // precision  extern int cg_maxiter; // not worth having a large number  extern int ls_maxiter;  // line search -extern double line_prec; // precision -extern double line_turb; // purturbation +extern double ls_prec; // precision +extern double ls_turb; // perturbation  #endif
\ No newline at end of file diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index b599a03..4572484 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -192,12 +192,12 @@ int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd          g = grad.dot(step);          cal_Hs(D,rank,corr,alpha,step,Hs);          h = Hs.dot(step); -        g=g+line_turb; -        h = h+line_turb; +        g=g+ls_turb; +        h = h+ls_turb;          t=t-g/h;          ++iter;          LOG(INFO) << "line search iter "<<iter<<", prec:"<<g*g/h; -        if (g*g/h<line_prec) +        if (g*g/h<ls_prec)              break;          if (iter >= ls_maxiter)          { @@ -27,7 +27,9 @@ int train(DataProvider &dp) {      LOG(INFO)<<"Training started";      dp.getAllDataSet(D);      LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; -    LOG(INFO)<<"C: "<<C; +    LOG(INFO)<<"C: "<<C<<" ,iter: "<<maxiter<<" ,prec: "<<prec; +    LOG(INFO)<<"cg_maxiter: "<<cg_maxiter<<" ,cg_prec:"<<cg_prec<<" ,ls_maxiter: "<<ls_maxiter<<" ,ls_prec: "<<ls_prec; +      rsvm->train(D);      LOG(INFO)<<"Training finished,saving model"; @@ -109,7 +111,13 @@ int main(int argc, char **argv) {              ("model,m", po::value<string>(), "set input model file")              ("output,o", po::value<string>(), "set output model/prediction file")              ("feature,i", po::value<string>(), "set input feature file") -            ("c,c",po::value<double>(),"trades margin size against training error"); +            ("c,c",po::value<double>(),"trades margin size against training error") +            ("iter",po::value<int>(),"iter main") +            ("prec",po::value<double>(),"prec main") +            ("cg_iter",po::value<int>(),"iter conjugate gradient") +            ("cg_prec",po::value<double>(),"prec conjugate gradient") +            ("ls_iter",po::value<int>(),"iter line search") +            ("ls_prec",po::value<double>(),"prec line search");      // Parsing program options      po::store(po::parse_command_line(argc, argv, desc), vm); @@ -127,14 +135,15 @@ int main(int argc, char **argv) {      el::Loggers::reconfigureLogger("default", defaultConf);      mainFunc mainf; -    if (vm.count("single")) -        RidList::single=true; -    else -        RidList::single=false; +    RidList::single=vm.count("single")>0;      if (vm.count("train")) { -        if (vm.count("c")) { -            C=vm["c"].as<double>(); -        } +        if (vm.count("c")) { C=vm["c"].as<double>(); } +        if (vm.count("iter")) { maxiter=vm["iter"].as<int>(); } +        if (vm.count("prec")) { prec=vm["prec"].as<double>(); } +        if (vm.count("cg_iter")) { cg_maxiter=vm["cg_iter"].as<int>(); } +        if (vm.count("cg_prec")) { cg_prec=vm["cg_prec"].as<double>(); } +        if (vm.count("ls_iter")) { ls_maxiter=vm["ls_iter"].as<int>(); } +        if (vm.count("ls_prec")) { ls_prec=vm["ls_prec"].as<double>(); }          mainf = &train;      }      else if (vm.count("validate")||vm.count("predict")) {  | 
