diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-05-19 13:26:20 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-05-19 13:26:20 +0800 |
commit | 3a45c9dd4620636645817d1cd57c8793169fae0b (patch) | |
tree | 1e6ad66099d8d6e667f34f02ffafe799fb337111 /model | |
parent | 615c70adb468de9b653930412ba87937f8e7f76b (diff) | |
download | ranksvm-3a45c9dd4620636645817d1cd57c8793169fae0b.tar.gz ranksvm-3a45c9dd4620636645817d1cd57c8793169fae0b.tar.bz2 ranksvm-3a45c9dd4620636645817d1cd57c8793169fae0b.zip |
added contraint r_1<r_2
Diffstat (limited to 'model')
-rw-r--r-- | model/ranksvmtn.cpp | 44 |
1 files changed, 25 insertions, 19 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 426de3a..2121008 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -49,21 +49,22 @@ int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorX VectorXd Ds(n); cal_Dw(D,s,Ds); VectorXd gamma(n); - for (int i=0;i<n;) + for (int i=0;i<n;i+=q) { double g=0; + // find B, cal A for (int j = q-1;j>=0;--j) if (corr[rank[i+j]]>0) gamma[rank[i+j]]=g; else g+=Ds[rank[i+j]]; g=0; - i+=q; - for (int j = q;j>0;--j) - if (corr[rank[i-j]]<0) - gamma[rank[i-j]]=g; + // find A, cal B + for (int j = 0;j<q;++j) + if (corr[rank[i+j]]<0) + gamma[rank[i+j]]=g; else - g+=Ds[rank[i-j]]; + g+=Ds[rank[i+j]]; } VectorXd tmp = alpha.cwiseProduct(Ds)-gamma; VectorXd res = VectorXd::Zero(D.getSize()); @@ -80,30 +81,33 @@ int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const Vecto VectorXd Hs; cal_Hs(D,rank,corr,alph,x,Hs); VectorXd res = b - Hs; + // Non preconditioned version VectorXd p = res; + r_1 = res.dot(res); while (1) { - // Non preconditioned version + cal_Hs(D,rank,corr,alph,p,q); + alpha = r_1/p.dot(q); + x=x+p*alpha; + res=res-q*alpha; + ++iter; + r_2=r_1; r_1 = res.dot(res); - if (iter) LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1; if (r_1<cg_prec) // Terminate condition break; + if (r_1>r_2) + { + LOG(INFO) << "CG forced termination by backward constraint, reverting"; + x=x-p*alpha; + } if (iter >= cg_maxiter) { LOG(INFO) << "CG forced termination by maxiter"; break; } - if (iter){ - beta = r_1 / r_2; - p = res + p * beta; - } - cal_Hs(D,rank,corr,alph,p,q); - alpha = r_1/p.dot(q); - x=x+p*alpha; - res=res-q*alpha; - ++iter; - r_2=r_1; + beta = r_1 / r_2; + p = res + p * beta; } return 0; } @@ -144,6 +148,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int int ed=i+q-1; ranksort(i,ed,rank,yt); double a=0,b=0; + // find A, cal B for (int j=i;j<=ed;++j) if (corr[rank[j]]<0) { @@ -156,6 +161,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int b+=yt[rank[j]]; } a=b=0; + // find B, cal A for (int j=ed;j>=i;--j) if (corr[rank[j]]>0) { @@ -226,6 +232,7 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){ VectorXd dw(n); VectorXd yt; VectorXd alpha,beta; + step = VectorXd::Zero(fsize); while (true) { cal_Dw(Data,weight,dw); @@ -236,7 +243,6 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){ VectorXd res = VectorXd::Zero(fsize); cal_Dtw(Data,tmp,res); grad = weight + C*res; - step = VectorXd::Zero(fsize); // Solve cg_solve(Data,rank,corr,alpha,grad,step); // do line search |