diff options
| -rw-r--r-- | model/ranksvm.cpp | 2 | ||||
| -rw-r--r-- | model/ranksvm.h | 2 | ||||
| -rw-r--r-- | model/ranksvmtn.cpp | 29 | ||||
| -rw-r--r-- | tools/fileDataProvider.cpp | 6 | ||||
| -rw-r--r-- | train.cpp | 7 | 
5 files changed, 33 insertions, 13 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 7ee72ac..dd2d9e4 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -8,6 +8,8 @@  using namespace Eigen;  using namespace std; +double C=1e-4; // Compensating & scaling +  int RSVM::saveModel(const string fname){      std::ofstream fout(fname.c_str()); diff --git a/model/ranksvm.h b/model/ranksvm.h index aa5e1ca..26a3f6d 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -37,4 +37,6 @@ public:      int setModel(const SVMModel &model);  }; +extern double C; +  #endif
\ No newline at end of file diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 01d9851..0e41861 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -6,13 +6,13 @@  using namespace std;  using namespace Eigen; -const double C=1e-4; // Compensating & scaling  // Main terminating criteria -const int maxiter = 20; // max iteration count +const int maxiter = 40; // max iteration count  const double prec=1e-10; // precision  // conjugate gradient  const double cg_prec=1e-10; // precision -const int cg_maxiter = 30; +const int cg_maxiter = 5; // not worth having a large number +const int ls_maxiter = 10;  // line search  const double line_prec=1e-10; // precision  const double line_turb=1e-15; // purturbation @@ -54,11 +54,13 @@ int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,cons      {          // Non preconditioned version          r_1 = res.dot(res); +        if (iter) +            LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1;          if (r_1<cg_prec) // Terminate condition              break; -        if (iter > cg_maxiter) +        if (iter >= cg_maxiter)          { -            LOG(INFO) << "CG forced termination by maxiter, r:"<<r_1; +            LOG(INFO) << "CG forced termination by maxiter";              break;          }          if (iter){ @@ -162,10 +164,11 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v          g=g+line_turb;          h = h+line_turb;          t=t-g/h; +        ++iter; +        LOG(INFO) << "line search iter "<<iter<<", prec:"<<g*g/h;          if (g*g/h<line_prec)              break; -        ++iter; -        if (iter > cg_maxiter) +        if (iter >= ls_maxiter)          {              LOG(INFO) << "line search forced termination by maxiter, prec:"<<g*g/h;              break; @@ -189,7 +192,6 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A      VectorXd alpha,beta;      while (true)      { -        iter+=1;          if (iter> maxiter)          {              LOG(INFO)<< "Maxiter reached"; @@ -208,8 +210,15 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A          line_search(weight,D,corr,A1,A2,step,t);          weight=weight+step*t;          // When dec is small enough -        LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Newton decr:"<<step.dot(grad)/2 << " linesearch: "<< -t ; -        if (step.dot(grad) < prec * obj) +        double nprec = step.dot(grad)/obj; +        ++iter; +        LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Ndec/Obj:"<<nprec << " linesearch: "<< -t ; +        if (iter> maxiter) +        { +            LOG(INFO)<< "Maxiter reached"; +            break; +        } +        if (nprec < prec)              break;      }      return 0; diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp index e9b7f3d..72330d5 100644 --- a/tools/fileDataProvider.cpp +++ b/tools/fileDataProvider.cpp @@ -58,6 +58,9 @@ void RidFileDP::readEntries() {          d.addEntry(e);      }      pos = 0; +    std::vector<DataEntry*> & dat = d.getData(); +    while (pos<dat.size() && dat[pos]->rank!=-1 && dat[pos]->qid!="-1") +        ++pos;      qid = 1;      read = true;  } @@ -92,10 +95,9 @@ int RidFileDP::getDataSet(DataList &out){              }              out.addEntry(e);          } -    dat[pos]->qid=std::to_string(qid);      ++qid;      dat[pos]->rank=qid; -    while (pos<dat.size() && dat[pos]->rank!=-1) +    while (pos<dat.size() && (dat[pos]->rank!=-1 || dat[pos]->qid=="-1"))          ++pos;      if (pos==d.getSize())          eof = true; @@ -27,6 +27,7 @@ int train(DataProvider &dp) {      LOG(INFO)<<"Training started";      dp.getAllDataSet(D);      LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features"; +    LOG(INFO)<<"C: "<<C;      rsvm->train(D);      vector<double> L;      rsvm->predict(D,L); @@ -111,7 +112,8 @@ int main(int argc, char **argv) {              ("debug,d", "show debug messages")              ("model,m", po::value<string>(), "set input model file")              ("output,o", po::value<string>(), "set output model/prediction file") -            ("feature,i", po::value<string>(), "set input feature file"); +            ("feature,i", po::value<string>(), "set input feature file") +            ("c,c",po::value<double>(),"trades margin size against training error");      // Parsing program options      po::store(po::parse_command_line(argc, argv, desc), vm); @@ -130,6 +132,9 @@ int main(int argc, char **argv) {      mainFunc mainf;      if (vm.count("train")) { +        if (vm.count("c")) { +            C=vm["c"].as<double>(); +        }          mainf = &train;      }      else if (vm.count("validate")||vm.count("predict")) {  | 
