-rw-r--r--  model/ranksvm.h             |   2
-rw-r--r--  model/ranksvmtn.cpp         | 126
-rw-r--r--  model/ranksvmtn.h           |   2
-rw-r--r--  tools/dataProvider.h        |  81
-rw-r--r--  tools/fileDataProvider.cpp  |  11
-rw-r--r--  tools/fileDataProvider.h    |   2
-rw-r--r--  train.cpp                   |   4
7 files changed, 144 insertions(+), 84 deletions(-)
diff --git a/model/ranksvm.h b/model/ranksvm.h
index aa5e1ca..58cd7f0 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -25,7 +25,7 @@ protected:
SVMModel model;
int fsize;
public:
- virtual int train(DataList &D)=0;
+ virtual int train(RidList &D)=0;
virtual int predict(DataList &D,std::vector<double> &res)=0;
// TODO Not sure how to construct this
// Possible solution: generate an n x n matrix where each row contains the sorted ranker results.
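
The signature change above is the theme of the whole commit: everywhere below, the dense pairwise matrix `D` (whose query blocks were delimited by the `A1`/`A2` index vectors) is replaced by a `RidList` that synthesizes each difference row on demand. A minimal sketch of the resulting access pattern, assuming nothing beyond a type exposing `getSize()` and `getVec(i)` like the `RidList` added in tools/dataProvider.h (the function names here are invented for illustration):

```cpp
#include <Eigen/Dense>

// Implicit D*w: never materializes D, generates one row per iteration.
template <class Rows>
Eigen::VectorXd implicitDw(Rows &D, const Eigen::VectorXd &w) {
    Eigen::VectorXd dw(D.getSize());
    for (int i = 0; i < D.getSize(); ++i)
        dw(i) = D.getVec(i).dot(w);      // row i of D, built on the fly
    return dw;
}

// Implicit D^T * v: accumulates the rank-1 contribution of each row.
template <class Rows>
Eigen::VectorXd implicitDtv(Rows &D, const Eigen::VectorXd &v, int fsize) {
    Eigen::VectorXd res = Eigen::VectorXd::Zero(fsize);
    for (int i = 0; i < D.getSize(); ++i)
        res += D.getVec(i) * v(i);
    return res;
}
```

This trades the O(size × fsize) memory of the explicit matrix for repeated `getVec` calls; `cal_Hs`, `line_search`, and `train_orig` below all inline exactly these two loops.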
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 01d9851..d3ef3af 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -17,37 +17,44 @@ const int cg_maxiter = 30;
const double line_prec=1e-10; // precision
const double line_turb=1e-15; // perturbation
-int cal_Hs(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const vector<int> &A1,const vector<int> &A2,const VectorXd s,VectorXd &Hs)
+int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const VectorXd s,VectorXd &Hs)
{
Hs = VectorXd::Zero(s.rows());
- VectorXd Ds=D*s;
- VectorXd gamma(D.rows());
- for (int i=0;i<A1.size();++i)
+ VectorXd Ds(D.getSize());
+ for (int i=0;i<D.getSize();++i)
+ Ds(i) = D.getVec(i).dot(s);
+ VectorXd gamma(D.getSize());
+ for (int i=0;i<D.getSize();)
{
double g=0;
- for (int j = A1[i];j<=A2[i];++j)
- if (corr[rank[j]]<0)
- gamma[rank[j]]=g;
+ for (int j = D.getqSize()-1;j>=0;--j)
+ if (corr[rank[i+j]]>0)
+ gamma[rank[i+j]]=g;
else
- g+=Ds[rank[j]];
+ g+=Ds[rank[i+j]];
g=0;
- for (int j = A2[i];j>=A1[i];--j)
- if (corr[rank[j]]>0)
- gamma[rank[j]]=g;
+ i+=D.getqSize();
+ for (int j = D.getqSize();j>0;--j)
+ if (corr[rank[i-j]]<0)
+ gamma[rank[i-j]]=g;
else
- g+=Ds[rank[j]];
+ g+=Ds[rank[i-j]];
}
- Hs = s + C*(D.transpose()*(alpha.cwiseProduct(Ds) - gamma));
+ VectorXd tmp = alpha.cwiseProduct(Ds)-gamma;
+ VectorXd res = 0*s;
+ for (int i=0;i<D.getSize();++i)
+ res = res + D.getVec(i)*tmp[i];
+ Hs = s + C*res;
return 0;
}
-int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const vector<int> &A1,const vector<int> &A2,const VectorXd &b, VectorXd &x)
+int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const VectorXd &b, VectorXd &x)
{
double alpha,beta,r_1,r_2;
int iter=0;
VectorXd q;
VectorXd Hs;
- cal_Hs(D,rank,corr,alph,A1,A2,x,Hs);
+ cal_Hs(D,rank,corr,alph,x,Hs);
VectorXd res = b - Hs;
VectorXd p = res;
while (1)
@@ -65,7 +72,7 @@ int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,cons
beta = r_1 / r_2;
p = res + p * beta;
}
- cal_Hs(D,rank,corr,alph,A1,A2,p,q);
+ cal_Hs(D,rank,corr,alph,p,q);
alpha = r_1/p.dot(q);
x=x+p*alpha;
res=res-q*alpha;
@@ -98,18 +105,19 @@ void ranksort(int l,int r,vector<int> &rank,VectorXd &ref)
ranksort(i,r,rank,ref);
}
-int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1,const vector<int> &A2,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta)
+int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta)
{
long n = dw.rows();
yt = dw - corr;
alpha=VectorXd::Zero(n);
beta=VectorXd::Zero(n);
for (int i=0;i<n;++i) rank[i]=i;
- for (int i=0;i<A1.size();++i)
+ for (int i=0;i<D.getSize();i+=D.getqSize())
{
- ranksort(A1[i],A2[i],rank,yt);
+ int ed=i+D.getqSize()-1;
+ ranksort(i,ed,rank,yt);
double a=0,b=0;
- for (int j=A1[i];j<=A2[i];++j)
+ for (int j=i;j<=ed;++j)
if (corr[rank[j]]<0)
{
alpha[rank[j]]=a;
@@ -121,7 +129,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1
b+=yt[rank[j]];
}
a=b=0;
- for (int j=A2[i];j>=A1[i];--j)
+ for (int j=ed;j>=i;--j)
if (corr[rank[j]]>0)
{
alpha[rank[j]]=a;
@@ -136,28 +144,33 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1
}
// line search using newton method
-int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const vector<int> &A1,const vector<int> &A2,const VectorXd &step,double &t)
+int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd &step,double &t)
{
- VectorXd Dd = D*step;
- VectorXd Xd = VectorXd::Zero(A1.size());
+ VectorXd Dd(D.getSize());
+ for (int i=0;i<D.getSize();++i)
+ Dd(i) = D.getVec(i).dot(step);
VectorXd alpha,beta,yt;
VectorXd grad;
VectorXd Hs;
- vector<int> rank(D.rows());
+ int n=D.getSize();
+ vector<int> rank(D.getSize());
int iter = 0;
- for (int i=0;i<A1.size();++i)
- Xd(i) = Dd(A1[i])-Dd(A2[i]);
double g,h;
t = 0;
while (1)
{
grad=w+t*step;
- Dd = D*(w + t*step);
- cal_alpha_beta(Dd,corr,A1,A2,rank,yt,alpha,beta);
- grad = grad + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta));
+ for (int i=0;i<D.getSize();++i)
+ Dd(i) = D.getVec(i).dot(grad);
+ cal_alpha_beta(Dd,corr,D,rank,yt,alpha,beta);
+ VectorXd tmp = alpha.cwiseProduct(yt)-beta;
+ VectorXd res = 0*grad;
+ for (int i=0;i<n;++i)
+ res = res + D.getVec(i)*tmp[i];
+ grad = grad + C*res;
g = grad.dot(step);
- cal_Hs(D,rank,corr,alpha,A1,A2,step,Hs);
+ cal_Hs(D,rank,corr,alpha,step,Hs);
h = Hs.dot(step);
g=g+line_turb;
h = h+line_turb;
@@ -174,17 +187,17 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v
return 0;
}
-int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &corr,VectorXd &weight){
+int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
int iter = 0;
- long n=D.rows();
- LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Query size:" << A1.size();
+ long n=Data.getSize();
+ LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Query size:" << Data.getuSize();
VectorXd grad(fsize);
VectorXd step(fsize);
vector<int> rank(n);
double obj,t;
- VectorXd dw = D*weight;
+ VectorXd dw(n);
VectorXd yt;
VectorXd alpha,beta;
while (true)
@@ -196,16 +209,21 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
break;
}
- dw = D*weight;
- cal_alpha_beta(dw,corr,A1,A2,rank,yt,alpha,beta);
+ for (int i=0;i<n;++i)
+ dw(i) = Data.getVec(i).dot(weight);
+ cal_alpha_beta(dw,corr,Data,rank,yt,alpha,beta);
// Generate support vector matrix sv & gradient
- obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;//
- grad = weight + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta));
+ obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;
+ VectorXd tmp = alpha.cwiseProduct(yt)-beta;
+ VectorXd res = 0*weight;
+ for (int i=0;i<n;++i)
+ res = res + Data.getVec(i)*tmp[i];
+ grad = weight + C*res;
step = grad*0;
// Solve
- cg_solve(D,rank,corr,alpha,A1,A2,grad,step);
+ cg_solve(Data,rank,corr,alpha,grad,step);
// do line search
- line_search(weight,D,corr,A1,A2,step,t);
+ line_search(weight,Data,corr,step,t);
weight=weight+step*t;
// When dec is small enough
LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Newton decr:"<<step.dot(grad)/2 << " linesearch: "<< -t ;
@@ -215,32 +233,14 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
return 0;
}
-int RSVMTN::train(DataList &D){
- MatrixXd Data(D.getSize(),D.getfSize());
+int RSVMTN::train(RidList &D){
VectorXd corr(D.getSize());
vector<int> A1,A2;
int i,j;
LOG(INFO)<<"Processing input";
- vector<DataEntry*> &dat = D.getData();
- for (i=0;i<D.getSize();++i) {
- corr(i)=(dat[i])->rank>0?0.5:-0.5;
-
- for (j = 0; j < D.getfSize(); ++j){
- Data(i, j) = dat[i]->feature(j);}
-
- }
- i=j=0;
- while (i<D.getSize())
- {
- if ((i+1 == D.getSize())|| dat[i]->qid!=dat[i+1]->qid)
- {
- A1.push_back(j);
- A2.push_back(i);
- j = i+1;
- }
- ++i;
- }
- train_orig(fsize,Data,A1,A2,corr,model.weight);
+ for (i=0;i<D.getSize();++i)
+ corr(i)=D.getL(i)>0?0.5:-0.5;
+ train_orig(fsize,D,corr,model.weight);
return 0;
};
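
For orientation, the quantities in the rewritten `train_orig` loop transcribe directly into math (this is a reading of the code above, not project documentation): with the per-query prefix sums α, β from `cal_alpha_beta` and γ from `cal_Hs`,

```latex
\begin{aligned}
y_t &= D w - \mathrm{corr},\\
\mathrm{obj} &= \tfrac{1}{2}\Big(w^\top w
      + C\big(\alpha^\top (y_t \circ y_t) - \beta^\top y_t\big)\Big),\\
\mathrm{grad} &= w + C\,D^\top(\alpha \circ y_t - \beta),\\
H s &= s + C\,D^\top\big(\alpha \circ (D s) - \gamma\big).
\end{aligned}
```

The only operations touching D are the row-wise products, which is what lets the dense matrix disappear; the conjugate-gradient solver in `cg_solve` needs nothing beyond the Hessian-vector product Hs.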
diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h
index 4074781..c98e581 100644
--- a/model/ranksvmtn.h
+++ b/model/ranksvmtn.h
@@ -12,7 +12,7 @@ public:
{
return "TN";
};
- virtual int train(DataList &D);
+ virtual int train(RidList &D);
virtual int predict(DataList &D,std::vector<double> &res);
};
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index 028980e..a3f3d34 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -55,6 +55,70 @@ public:
}
};
+class RidList{
+private:
+ int n;
+ std::vector<DataEntry*> uniq;
+ std::vector<DataEntry*> other;
+public:
+ void clear(){
+ uniq.clear();
+ other.clear();
+ }
+ void setfSize(int fsize){n=fsize;}
+ int getfSize(){return n;}
+ void addEntry(DataEntry* d){
+ int ext=false;
+ for (int i=0;i<uniq.size();++i)
+ if (uniq[i]->qid==d->qid)
+ {
+ ext = true;
+ d->rank = i;
+ }
+ if (ext)
+ other.push_back(d);
+ else
+ uniq.push_back(d);
+ }
+ int getqSize()
+ {
+ return (int)(uniq.size()+other.size()-1);
+ }
+ int getuSize()
+ {
+ return (int)uniq.size();
+ }
+ int getSize()
+ {
+ return getuSize()*getqSize();
+ }
+ Eigen::VectorXd getVec(int x){
+ int a,b,n=getqSize();
+ a=x/n;
+ b=x%n;
+ Eigen::VectorXd vec = uniq[a]->feature;
+ if (b<a)
+ vec=vec-uniq[b]->feature;
+ else
+ if (b<uniq.size()-1)
+ vec=vec-uniq[b+1]->feature;
+ else
+ vec=vec-other[b-uniq.size()+1]->feature;
+ return vec.cwiseAbs();
+ };
+ double getL(int x){
+ int a,b,n=(int)(uniq.size()+other.size()-1);
+ a=x/n;
+ b=x%n;
+ if (b<uniq.size()-1)
+ return -1;
+ else
+ if (std::fabs(other[b-uniq.size()+1]->rank - a) < 1e-5)
+ return 1;
+ return -1;
+ };
+};
+
class DataProvider //Virtual base class for data input
{
protected:
@@ -63,22 +127,7 @@ public:
DataProvider():eof(false){};
bool EOFile(){return eof;}
- void getAllDataSet(DataList &out){\
- out.clear();
- DataList buf;
- while (!EOFile())
- {
- getDataSet(buf);
- // won't work as data are discarded with every call to getDataSet
- // out.getData().insert(out.getData().end(),buf.getData().begin(),buf.getData().end());
- for (int i=0;i<buf.getSize();++i)
- {
- out.addEntry(out.copyEntry(buf.getData()[i]));
- }
- out.setfSize(buf.getfSize());
- }
- buf.clear();
- }
+ virtual void getAllDataSet(RidList &out) = 0;
virtual int getDataSet(DataList &out) = 0;
virtual int open()=0;
virtual int close()=0;
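
The indexing behind `getVec`/`getL` above is compact but easy to misread: with `n = getqSize()`, index `x` decodes as anchor `a = x / n` into `uniq` and partner slot `b = x % n`, where the partner walks every unique entry except the anchor and then the `other` entries. A self-contained toy (sizes made up, branch structure copied from `getVec`) that prints which pair each index denotes:

```cpp
#include <cstdio>

int main() {
    const int u = 3, o = 2;           // stand-ins for uniq.size(), other.size()
    const int n = u + o - 1;          // getqSize(): partners per anchor
    for (int x = 0; x < u * n; ++x) { // getSize(): u * n implicit rows
        const int a = x / n;          // anchor into uniq
        const int b = x % n;          // raw partner slot
        if (b < a)
            std::printf("x=%2d: uniq[%d] vs uniq[%d]\n", x, a, b);
        else if (b < u - 1)           // shift past the anchor itself
            std::printf("x=%2d: uniq[%d] vs uniq[%d]\n", x, a, b + 1);
        else
            std::printf("x=%2d: uniq[%d] vs other[%d]\n", x, a, b - u + 1);
    }
    return 0;
}
```

Note that `getVec` returns the absolute feature difference, so the pair vector itself is order-independent; `getL` marks a pair positive exactly when the partner's stored `rank` (set by `addEntry` to the index of its matching `uniq` entry) points back at the anchor.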
diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp
index e9b7f3d..1ff0279 100644
--- a/tools/fileDataProvider.cpp
+++ b/tools/fileDataProvider.cpp
@@ -170,4 +170,15 @@ void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b)
b.push_back(tmp[i]);
scrambler(a);
scrambler(b);
+}
+
+void RidFileDP::getAllDataSet(RidList &out){
+ DataEntry *e;
+ if (!read)
+ readEntries();
+ out.clear();
+ std::vector<DataEntry*> &dat = d.getData();
+ for (int i=0;i<dat.size();++i)
+ out.addEntry(dat[i]);
+ out.setfSize(d.getfSize());
+}
\ No newline at end of file
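
Putting the new data path together, the intended call sequence (as train.cpp below exercises it) is roughly the following; the include path and input file name are guesses, and error handling is omitted:

```cpp
#include "tools/fileDataProvider.h"  // assumed include path

int main() {
    RidFileDP dp("train.rid");       // hypothetical input file
    dp.open();
    RidList D;
    dp.getAllDataSet(D);             // calls readEntries() lazily if needed
    // Caution: D stores raw DataEntry* owned by the provider's DataList,
    // and RidFileDP::close() clears that list -- keep dp open while D is in use.
    // ... rsvm->train(D) ...
    dp.close();
    return 0;
}
```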
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index 7bea92d..567c8e2 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -17,6 +17,7 @@ private:
public:
FileDP(std::string fn=""):fname(fn){};
virtual int getDataSet(DataList &out);
+ virtual void getAllDataSet(RidList &out){ LOG(FATAL)<<"getAllDataSet for normal FileDP not implemented";};
virtual int open(){fin.open(fname); eof=false;return 0;};
virtual int close(){fin.close();return 0;};
};
@@ -37,6 +38,7 @@ public:
void readEntries();
int getfSize() { if(!read) readEntries(); return d.getfSize();};
int getpSize();
+ virtual void getAllDataSet(RidList &out);
virtual int getDataSet(DataList &out);
virtual int open(){fin.open(fname); eof=false;return 0;};
virtual int close(){fin.close(); d.clear();return 0;};
diff --git a/train.cpp b/train.cpp
index a0c62a9..05787c9 100644
--- a/train.cpp
+++ b/train.cpp
@@ -22,14 +22,12 @@ int train(DataProvider &dp) {
rsvm = RSVM::loadModel(vm["model"].as<string>());
dp.open();
- DataList D;
+ RidList D;
LOG(INFO)<<"Training started";
dp.getAllDataSet(D);
LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
rsvm->train(D);
- vector<double> L;
- rsvm->predict(D,L);
LOG(INFO)<<"Training finished,saving model";