summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- main.cpp 2
-rw-r--r-- model/rankaccu.cpp 1
-rw-r--r-- model/ranksvmtn.cpp 27
-rw-r--r-- tools/dataProvider.h 11
-rw-r--r-- tools/fileDataProvider.h 6
5 files changed, 36 insertions, 11 deletions
diff --git a/main.cpp b/main.cpp
index 5c977d2..4f080bb 100644
--- a/main.cpp
+++ b/main.cpp
@@ -25,7 +25,7 @@ int train(DataProvider &dp) {
DataList D;
LOG(INFO)<<"Training started";
- dp.getAllData(D);
+ dp.getAllDataSet(D);
LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
rsvm->train(D);
vector<double> L;
diff --git a/model/rankaccu.cpp b/model/rankaccu.cpp
index 910e3e0..73b5d18 100644
--- a/model/rankaccu.cpp
+++ b/model/rankaccu.cpp
@@ -134,6 +134,7 @@ void rank_CMC(DataList &D,const std::vector<double> pred,CMC & cmc) {
for (int k=j;k<=i;++k)
if (orig[pred_rank[k]]>0)
{
+ cout<<pred_rank[k]<<" "<<pred[k+1]<<" "<< k <<" "<< j<<endl;
cmc.addEntry(k-j);
break; // account only for the first match;
}
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index b82ce64..c2ca639 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -6,12 +6,13 @@
using namespace std;
using namespace Eigen;
-const double C=1e-2; // Compensating & scaling
+const double C=1e-5; // Compensating & scaling
// Main terminating criteria
const int maxiter = 10; // max iteration count
-const double prec=1e-4; // precision
+const double prec=1e-10; // precision
// conjugate gradient
const double cg_prec=1e-10; // precision
+const int cg_maxiter = 30;
// line search
const double line_prec=1e-10; // precision
const double line_turb=1e-15; // purturbation
@@ -64,6 +65,11 @@ int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,cons
x=x+p*alpha;
res=res-q*alpha;
++step;
+ if (step > cg_maxiter)
+ {
+ LOG(INFO) << "CG force terminated by maxiter";
+ break;
+ }
r_2=r_1;
}
return 0;
@@ -138,6 +144,7 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v
VectorXd grad;
VectorXd Hs;
vector<int> rank(D.rows());
+ int iter = 0;
for (int i=0;i<A1.size();++i)
Xd(i) = Dd(A1[i])-Dd(A2[i]);
@@ -157,6 +164,12 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v
t=t-g/h;
if (g*g/h<line_prec)
break;
+ ++iter;
+ if (iter > cg_maxiter)
+ {
+ LOG(INFO) << "line search force terminated by maxiter";
+ break;
+ }
}
return 0;
}
@@ -179,14 +192,12 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
iter+=1;
if (iter> maxiter)
{
- LOG(INFO)<< "Maxiter :"<<maxiter<<" reached";
+ LOG(INFO)<< "Maxiter reached";
break;
}
dw = D*weight;
-
cal_alpha_beta(dw,corr,A1,A2,rank,yt,alpha,beta);
-
// Generate support vector matrix sv & gradient
obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;//
grad = weight + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta));
@@ -213,8 +224,10 @@ int RSVMTN::train(DataList &D){
vector<DataEntry*> &dat = D.getData();
for (i=0;i<D.getSize();++i) {
corr(i)=(dat[i])->rank>0?0.5:-0.5;
- for (j = 0; j < D.getfSize(); ++j)
- Data(i, j) = dat[i]->feature(j);
+
+ for (j = 0; j < D.getfSize(); ++j){
+ Data(i, j) = dat[i]->feature(j);}
+
}
i=j=0;
while (i<D.getSize())
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index 1d430e4..028980e 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -40,7 +40,11 @@ public:
DataEntry* dat = new DataEntry;
dat->rank = d->rank;
dat->qid = d->qid;
- dat->feature = d->feature;
+ dat->feature.resize(d->feature.rows());
+ for (int i=0;i<d->feature.rows();++i)
+ {
+ dat->feature(i)=d->feature(i);
+ }
return dat;
}
inline std::vector<DataEntry*>& getData(){
@@ -59,7 +63,7 @@ public:
DataProvider():eof(false){};
bool EOFile(){return eof;}
- void getAllData(DataList &out){\
+ void getAllDataSet(DataList &out){\
out.clear();
DataList buf;
while (!EOFile())
@@ -68,9 +72,12 @@ public:
// won't work as data are discarded with every call to getDataSet
// out.getData().insert(out.getData().end(),buf.getData().begin(),buf.getData().end());
for (int i=0;i<buf.getSize();++i)
+ {
out.addEntry(out.copyEntry(buf.getData()[i]));
+ }
out.setfSize(buf.getfSize());
}
+ buf.clear();
}
virtual int getDataSet(DataList &out) = 0;
virtual int open()=0;
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index c4f6a4a..8ebda20 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -82,9 +82,11 @@ public:
}
pos = 0;
qid = 1;
+ read = true;
}
out.clear();
fsize = d.getfSize();
+ out.setfSize(fsize);
std::vector<DataEntry*> & dat = d.getData();
for (int i=0;i<d.getSize();++i)
if (i!=pos)
@@ -94,6 +96,7 @@ public:
e = new DataEntry;
e->rank=1;
dat[i]->qid=std::to_string(qid);
+ dat[i]->rank=qid;
}
else
{
@@ -103,12 +106,13 @@ public:
e->feature.resize(d.getfSize());
e->qid=std::to_string(qid);
for (int j = 0; j < fsize; ++j) {
- e->feature(i) = fabs(dat[i]->feature(j) -dat[pos]->feature(j));
+ e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j));
}
out.addEntry(e);
}
dat[pos]->qid=std::to_string(qid);
++qid;
+ dat[pos]->rank=qid;
while (pos<dat.size() && dat[pos]->rank!=-1)
++pos;
if (pos==d.getSize())