summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--model/ranksvm.cpp6
-rw-r--r--model/ranksvm.h4
-rw-r--r--model/ranksvmtn.cpp6
-rw-r--r--train.cpp27
4 files changed, 26 insertions, 17 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp
index 2f366d6..368e16d 100644
--- a/model/ranksvm.cpp
+++ b/model/ranksvm.cpp
@@ -15,11 +15,11 @@ int maxiter = 1; // max iteration count
double prec=1e-10; // precision
// conjugate gradient
double cg_prec=1e-10; // precision
-int cg_maxiter = 2000;
+int cg_maxiter = 100;
int ls_maxiter = 20;
// line search
-double line_prec=1e-10; // precision
-double line_turb=1e-15; // purturbation
+double ls_prec=1e-10; // precision
+double ls_turb=1e-15; // purturbation
int RSVM::saveModel(const string fname){
diff --git a/model/ranksvm.h b/model/ranksvm.h
index 41de0f4..a17e3c9 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -46,7 +46,7 @@ extern double cg_prec; // precision
extern int cg_maxiter; // not worth having a large number
extern int ls_maxiter;
// line search
-extern double line_prec; // precision
-extern double line_turb; // purturbation
+extern double ls_prec; // precision
+extern double ls_turb; // perturbation
#endif \ No newline at end of file
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 9beed65..01fcb83 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -192,12 +192,12 @@ int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd
g = grad.dot(step);
cal_Hs(D,rank,corr,alpha,step,Hs);
h = Hs.dot(step);
- g=g+line_turb;
- h = h+line_turb;
+ g=g+ls_turb;
+ h = h+ls_turb;
t=t-g/h;
++iter;
LOG(INFO) << "line search iter "<<iter<<", prec:"<<g*g/h;
- if (g*g/h<line_prec)
+ if (g*g/h<ls_prec)
break;
if (iter >= ls_maxiter)
{
diff --git a/train.cpp b/train.cpp
index 4b8439a..07f9edc 100644
--- a/train.cpp
+++ b/train.cpp
@@ -27,7 +27,9 @@ int train(DataProvider &dp) {
LOG(INFO)<<"Training started";
dp.getAllDataSet(D);
LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
- LOG(INFO)<<"C: "<<C;
+ LOG(INFO)<<"C: "<<C<<" ,iter: "<<maxiter<<" ,prec: "<<prec;
+ LOG(INFO)<<"cg_maxiter: "<<cg_maxiter<<" ,cg_prec:"<<cg_prec<<" ,ls_maxiter: "<<ls_maxiter<<" ,ls_prec: "<<ls_prec;
+
rsvm->train(D);
LOG(INFO)<<"Training finished,saving model";
@@ -109,7 +111,13 @@ int main(int argc, char **argv) {
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file")
- ("c,c",po::value<double>(),"trades margin size against training error");
+ ("c,c",po::value<double>(),"trades margin size against training error")
+ ("iter",po::value<int>(),"iter main")
+ ("prec",po::value<double>(),"prec main")
+ ("cg_iter",po::value<int>(),"iter conjugate gradient")
+ ("cg_prec",po::value<double>(),"prec conjugate gradient")
+ ("ls_iter",po::value<int>(),"iter line search")
+ ("ls_prec",po::value<double>(),"prec line search");
// Parsing program options
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -127,14 +135,15 @@ int main(int argc, char **argv) {
el::Loggers::reconfigureLogger("default", defaultConf);
mainFunc mainf;
- if (vm.count("single"))
- RidList::single=true;
- else
- RidList::single=false;
+ RidList::single=vm.count("single")>0;
if (vm.count("train")) {
- if (vm.count("c")) {
- C=vm["c"].as<double>();
- }
+ if (vm.count("c")) { C=vm["c"].as<double>(); }
+ if (vm.count("iter")) { maxiter=vm["iter"].as<int>(); }
+ if (vm.count("prec")) { prec=vm["prec"].as<double>(); }
+ if (vm.count("cg_iter")) { cg_maxiter=vm["cg_iter"].as<int>(); }
+ if (vm.count("cg_prec")) { cg_prec=vm["cg_prec"].as<double>(); }
+ if (vm.count("ls_iter")) { ls_maxiter=vm["ls_iter"].as<int>(); }
+ if (vm.count("ls_prec")) { ls_prec=vm["ls_prec"].as<double>(); }
mainf = &train;
}
else if (vm.count("validate")||vm.count("predict")) {