summaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-05-19 13:26:20 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-05-19 13:26:20 +0800
commit3a45c9dd4620636645817d1cd57c8793169fae0b (patch)
tree1e6ad66099d8d6e667f34f02ffafe799fb337111 /model
parent615c70adb468de9b653930412ba87937f8e7f76b (diff)
downloadranksvm-3a45c9dd4620636645817d1cd57c8793169fae0b.tar.gz
ranksvm-3a45c9dd4620636645817d1cd57c8793169fae0b.tar.bz2
ranksvm-3a45c9dd4620636645817d1cd57c8793169fae0b.zip
added contraint r_1<r_2
Diffstat (limited to 'model')
-rw-r--r--model/ranksvmtn.cpp44
1 files changed, 25 insertions, 19 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 426de3a..2121008 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -49,21 +49,22 @@ int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorX
VectorXd Ds(n);
cal_Dw(D,s,Ds);
VectorXd gamma(n);
- for (int i=0;i<n;)
+ for (int i=0;i<n;i+=q)
{
double g=0;
+ // find B, cal A
for (int j = q-1;j>=0;--j)
if (corr[rank[i+j]]>0)
gamma[rank[i+j]]=g;
else
g+=Ds[rank[i+j]];
g=0;
- i+=q;
- for (int j = q;j>0;--j)
- if (corr[rank[i-j]]<0)
- gamma[rank[i-j]]=g;
+ // find A, cal B
+ for (int j = 0;j<q;++j)
+ if (corr[rank[i+j]]<0)
+ gamma[rank[i+j]]=g;
else
- g+=Ds[rank[i-j]];
+ g+=Ds[rank[i+j]];
}
VectorXd tmp = alpha.cwiseProduct(Ds)-gamma;
VectorXd res = VectorXd::Zero(D.getSize());
@@ -80,30 +81,33 @@ int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const Vecto
VectorXd Hs;
cal_Hs(D,rank,corr,alph,x,Hs);
VectorXd res = b - Hs;
+ // Non preconditioned version
VectorXd p = res;
+ r_1 = res.dot(res);
while (1)
{
- // Non preconditioned version
+ cal_Hs(D,rank,corr,alph,p,q);
+ alpha = r_1/p.dot(q);
+ x=x+p*alpha;
+ res=res-q*alpha;
+ ++iter;
+ r_2=r_1;
r_1 = res.dot(res);
- if (iter)
LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1;
if (r_1<cg_prec) // Terminate condition
break;
+ if (r_1>r_2)
+ {
+ LOG(INFO) << "CG forced termination by backward constraint, reverting";
+ x=x-p*alpha;
+ }
if (iter >= cg_maxiter)
{
LOG(INFO) << "CG forced termination by maxiter";
break;
}
- if (iter){
- beta = r_1 / r_2;
- p = res + p * beta;
- }
- cal_Hs(D,rank,corr,alph,p,q);
- alpha = r_1/p.dot(q);
- x=x+p*alpha;
- res=res-q*alpha;
- ++iter;
- r_2=r_1;
+ beta = r_1 / r_2;
+ p = res + p * beta;
}
return 0;
}
@@ -144,6 +148,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int
int ed=i+q-1;
ranksort(i,ed,rank,yt);
double a=0,b=0;
+ // find A, cal B
for (int j=i;j<=ed;++j)
if (corr[rank[j]]<0)
{
@@ -156,6 +161,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int
b+=yt[rank[j]];
}
a=b=0;
+ // find B, cal A
for (int j=ed;j>=i;--j)
if (corr[rank[j]]>0)
{
@@ -226,6 +232,7 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
VectorXd dw(n);
VectorXd yt;
VectorXd alpha,beta;
+ step = VectorXd::Zero(fsize);
while (true)
{
cal_Dw(Data,weight,dw);
@@ -236,7 +243,6 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
VectorXd res = VectorXd::Zero(fsize);
cal_Dtw(Data,tmp,res);
grad = weight + C*res;
- step = VectorXd::Zero(fsize);
// Solve
cg_solve(Data,rank,corr,alpha,grad,step);
// do line search