summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-05-18 16:17:01 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-05-18 16:17:01 +0800
commit653bee4b89131a997043a074d51c28dedb907f5c (patch)
treedd07194c85199aa1a3c38701bcc5d595f3d6f2b3
parent20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869 (diff)
downloadranksvm-mbk.tar.gz
ranksvm-mbk.tar.bz2
ranksvm-mbk.zip
output updatembk
-rw-r--r--model/ranksvm.cpp2
-rw-r--r--model/ranksvm.h2
-rw-r--r--model/ranksvmtn.cpp29
-rw-r--r--tools/fileDataProvider.cpp6
-rw-r--r--train.cpp7
5 files changed, 33 insertions, 13 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp
index 7ee72ac..dd2d9e4 100644
--- a/model/ranksvm.cpp
+++ b/model/ranksvm.cpp
@@ -8,6 +8,8 @@
using namespace Eigen;
using namespace std;
+double C=1e-4; // Compensating & scaling
+
int RSVM::saveModel(const string fname){
std::ofstream fout(fname.c_str());
diff --git a/model/ranksvm.h b/model/ranksvm.h
index aa5e1ca..26a3f6d 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -37,4 +37,6 @@ public:
int setModel(const SVMModel &model);
};
+extern double C;
+
#endif \ No newline at end of file
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 01d9851..0e41861 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -6,13 +6,13 @@
using namespace std;
using namespace Eigen;
-const double C=1e-4; // Compensating & scaling
// Main terminating criteria
-const int maxiter = 20; // max iteration count
+const int maxiter = 40; // max iteration count
const double prec=1e-10; // precision
// conjugate gradient
const double cg_prec=1e-10; // precision
-const int cg_maxiter = 30;
+const int cg_maxiter = 5; // not worth having a large number
+const int ls_maxiter = 10;
// line search
const double line_prec=1e-10; // precision
const double line_turb=1e-15; // purturbation
@@ -54,11 +54,13 @@ int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,cons
{
// Non preconditioned version
r_1 = res.dot(res);
+ if (iter)
+ LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1;
if (r_1<cg_prec) // Terminate condition
break;
- if (iter > cg_maxiter)
+ if (iter >= cg_maxiter)
{
- LOG(INFO) << "CG forced termination by maxiter, r:"<<r_1;
+ LOG(INFO) << "CG forced termination by maxiter";
break;
}
if (iter){
@@ -162,10 +164,11 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v
g=g+line_turb;
h = h+line_turb;
t=t-g/h;
+ ++iter;
+ LOG(INFO) << "line search iter "<<iter<<", prec:"<<g*g/h;
if (g*g/h<line_prec)
break;
- ++iter;
- if (iter > cg_maxiter)
+ if (iter >= ls_maxiter)
{
LOG(INFO) << "line search forced termination by maxiter, prec:"<<g*g/h;
break;
@@ -189,7 +192,6 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
VectorXd alpha,beta;
while (true)
{
- iter+=1;
if (iter> maxiter)
{
LOG(INFO)<< "Maxiter reached";
@@ -208,8 +210,15 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
line_search(weight,D,corr,A1,A2,step,t);
weight=weight+step*t;
// When dec is small enough
- LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Newton decr:"<<step.dot(grad)/2 << " linesearch: "<< -t ;
- if (step.dot(grad) < prec * obj)
+ double nprec = step.dot(grad)/obj;
+ ++iter;
+ LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Ndec/Obj:"<<nprec << " linesearch: "<< -t ;
+ if (iter> maxiter)
+ {
+ LOG(INFO)<< "Maxiter reached";
+ break;
+ }
+ if (nprec < prec)
break;
}
return 0;
diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp
index e9b7f3d..72330d5 100644
--- a/tools/fileDataProvider.cpp
+++ b/tools/fileDataProvider.cpp
@@ -58,6 +58,9 @@ void RidFileDP::readEntries() {
d.addEntry(e);
}
pos = 0;
+ std::vector<DataEntry*> & dat = d.getData();
+ while (pos<dat.size() && dat[pos]->rank!=-1 && dat[pos]->qid!="-1")
+ ++pos;
qid = 1;
read = true;
}
@@ -92,10 +95,9 @@ int RidFileDP::getDataSet(DataList &out){
}
out.addEntry(e);
}
- dat[pos]->qid=std::to_string(qid);
++qid;
dat[pos]->rank=qid;
- while (pos<dat.size() && dat[pos]->rank!=-1)
+ while (pos<dat.size() && (dat[pos]->rank!=-1 || dat[pos]->qid=="-1"))
++pos;
if (pos==d.getSize())
eof = true;
diff --git a/train.cpp b/train.cpp
index a0c62a9..bae88f3 100644
--- a/train.cpp
+++ b/train.cpp
@@ -27,6 +27,7 @@ int train(DataProvider &dp) {
LOG(INFO)<<"Training started";
dp.getAllDataSet(D);
LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
+ LOG(INFO)<<"C: "<<C;
rsvm->train(D);
vector<double> L;
rsvm->predict(D,L);
@@ -111,7 +112,8 @@ int main(int argc, char **argv) {
("debug,d", "show debug messages")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
- ("feature,i", po::value<string>(), "set input feature file");
+ ("feature,i", po::value<string>(), "set input feature file")
+ ("c,c",po::value<double>(),"trades margin size against training error");
// Parsing program options
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -130,6 +132,9 @@ int main(int argc, char **argv) {
mainFunc mainf;
if (vm.count("train")) {
+ if (vm.count("c")) {
+ C=vm["c"].as<double>();
+ }
mainf = &train;
}
else if (vm.count("validate")||vm.count("predict")) {