diff options
| -rw-r--r-- | model/ranksvmtn.cpp | 44 | 
1 files changed, 25 insertions, 19 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 426de3a..2121008 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -49,21 +49,22 @@ int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorX      VectorXd Ds(n);      cal_Dw(D,s,Ds);      VectorXd gamma(n); -    for (int i=0;i<n;) +    for (int i=0;i<n;i+=q)      {          double g=0; +        // find B, cal A          for (int j = q-1;j>=0;--j)              if (corr[rank[i+j]]>0)                  gamma[rank[i+j]]=g;              else                  g+=Ds[rank[i+j]];          g=0; -        i+=q; -        for (int j = q;j>0;--j) -            if (corr[rank[i-j]]<0) -                gamma[rank[i-j]]=g; +        // find A, cal B +        for (int j = 0;j<q;++j) +            if (corr[rank[i+j]]<0) +                gamma[rank[i+j]]=g;              else -                g+=Ds[rank[i-j]]; +                g+=Ds[rank[i+j]];      }      VectorXd tmp = alpha.cwiseProduct(Ds)-gamma;      VectorXd res = VectorXd::Zero(D.getSize()); @@ -80,30 +81,33 @@ int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const Vecto      VectorXd Hs;      cal_Hs(D,rank,corr,alph,x,Hs);      VectorXd res = b - Hs; +    // Non preconditioned version      VectorXd p = res; +    r_1 = res.dot(res);      while (1)      { -        // Non preconditioned version +        cal_Hs(D,rank,corr,alph,p,q); +        alpha = r_1/p.dot(q); +        x=x+p*alpha; +        res=res-q*alpha; +        ++iter; +        r_2=r_1;          r_1 = res.dot(res); -        if (iter)          LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1;          if (r_1<cg_prec) // Terminate condition              break; +        if (r_1>r_2) +        { +            LOG(INFO) << "CG forced termination by backward constraint, reverting"; +            x=x-p*alpha; +        }          if (iter >= cg_maxiter)          {              LOG(INFO) << "CG forced termination by maxiter";              break;          } -        if (iter){ -            beta = r_1 / r_2; -            p = res + p * beta; -        } -        cal_Hs(D,rank,corr,alph,p,q); -        alpha = r_1/p.dot(q); -        x=x+p*alpha; -        res=res-q*alpha; -        ++iter; -        r_2=r_1; +        beta = r_1 / r_2; +        p = res + p * beta;      }      return 0;  } @@ -144,6 +148,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int          int ed=i+q-1;          ranksort(i,ed,rank,yt);          double a=0,b=0; +        // find A, cal B          for (int j=i;j<=ed;++j)              if (corr[rank[j]]<0)              { @@ -156,6 +161,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int                  b+=yt[rank[j]];              }          a=b=0; +        // find B, cal A          for (int j=ed;j>=i;--j)              if (corr[rank[j]]>0)              { @@ -226,6 +232,7 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){      VectorXd dw(n);      VectorXd yt;      VectorXd alpha,beta; +    step = VectorXd::Zero(fsize);      while (true)      {          cal_Dw(Data,weight,dw); @@ -236,7 +243,6 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){          VectorXd res = VectorXd::Zero(fsize);          cal_Dtw(Data,tmp,res);          grad = weight + C*res; -        step = VectorXd::Zero(fsize);          // Solve          cg_solve(Data,rank,corr,alpha,grad,step);          // do line search  | 
