From 3a45c9dd4620636645817d1cd57c8793169fae0b Mon Sep 17 00:00:00 2001
From: Joe Zhao <ztuowen@gmail.com>
Date: Tue, 19 May 2015 13:26:20 +0800
Subject: added contraint r_1<r_2

---
 model/ranksvmtn.cpp | 44 +++++++++++++++++++++++++-------------------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 426de3a..2121008 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -49,21 +49,22 @@ int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorX
     VectorXd Ds(n);
     cal_Dw(D,s,Ds);
     VectorXd gamma(n);
-    for (int i=0;i<n;)
+    for (int i=0;i<n;i+=q)
     {
         double g=0;
+        // find B, cal A
         for (int j = q-1;j>=0;--j)
             if (corr[rank[i+j]]>0)
                 gamma[rank[i+j]]=g;
             else
                 g+=Ds[rank[i+j]];
         g=0;
-        i+=q;
-        for (int j = q;j>0;--j)
-            if (corr[rank[i-j]]<0)
-                gamma[rank[i-j]]=g;
+        // find A, cal B
+        for (int j = 0;j<q;++j)
+            if (corr[rank[i+j]]<0)
+                gamma[rank[i+j]]=g;
             else
-                g+=Ds[rank[i-j]];
+                g+=Ds[rank[i+j]];
     }
     VectorXd tmp = alpha.cwiseProduct(Ds)-gamma;
     VectorXd res = VectorXd::Zero(D.getSize());
@@ -80,30 +81,33 @@ int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const Vecto
     VectorXd Hs;
     cal_Hs(D,rank,corr,alph,x,Hs);
     VectorXd res = b - Hs;
+    // Non preconditioned version
     VectorXd p = res;
+    r_1 = res.dot(res);
     while (1)
     {
-        // Non preconditioned version
+        cal_Hs(D,rank,corr,alph,p,q);
+        alpha = r_1/p.dot(q);
+        x=x+p*alpha;
+        res=res-q*alpha;
+        ++iter;
+        r_2=r_1;
         r_1 = res.dot(res);
-        if (iter)
         LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1;
         if (r_1<cg_prec) // Terminate condition
             break;
+        if (r_1>r_2)
+        {
+            LOG(INFO) << "CG forced termination by backward constraint, reverting";
+            x=x-p*alpha;
+        }
         if (iter >= cg_maxiter)
         {
             LOG(INFO) << "CG forced termination by maxiter";
             break;
         }
-        if (iter){
-            beta = r_1 / r_2;
-            p = res + p * beta;
-        }
-        cal_Hs(D,rank,corr,alph,p,q);
-        alpha = r_1/p.dot(q);
-        x=x+p*alpha;
-        res=res-q*alpha;
-        ++iter;
-        r_2=r_1;
+        beta = r_1 / r_2;
+        p = res + p * beta;
     }
     return 0;
 }
@@ -144,6 +148,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int
         int ed=i+q-1;
         ranksort(i,ed,rank,yt);
         double a=0,b=0;
+        // find A, cal B
         for (int j=i;j<=ed;++j)
             if (corr[rank[j]]<0)
             {
@@ -156,6 +161,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int
                 b+=yt[rank[j]];
             }
         a=b=0;
+        // find B, cal A
         for (int j=ed;j>=i;--j)
             if (corr[rank[j]]>0)
             {
@@ -226,6 +232,7 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
     VectorXd dw(n);
     VectorXd yt;
     VectorXd alpha,beta;
+    step = VectorXd::Zero(fsize);
     while (true)
     {
         cal_Dw(Data,weight,dw);
@@ -236,7 +243,6 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
         VectorXd res = VectorXd::Zero(fsize);
         cal_Dtw(Data,tmp,res);
         grad = weight + C*res;
-        step = VectorXd::Zero(fsize);
         // Solve
         cg_solve(Data,rank,corr,alpha,grad,step);
         // do line search
-- 
cgit v1.2.3-70-g09d2