diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-05-19 00:11:23 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-05-19 00:11:23 +0800 | 
| commit | 615c70adb468de9b653930412ba87937f8e7f76b (patch) | |
| tree | 8a10083339f81f9a7eb77c5c2114c4543ee78139 /model | |
| parent | accb2b987662db4c8ced69aab78c95b622030576 (diff) | |
| download | ranksvm-615c70adb468de9b653930412ba87937f8e7f76b.tar.gz ranksvm-615c70adb468de9b653930412ba87937f8e7f76b.tar.bz2 ranksvm-615c70adb468de9b653930412ba87937f8e7f76b.zip  | |
openmp cont
Diffstat (limited to 'model')
| -rw-r--r-- | model/ranksvmtn.cpp | 15 | 
1 files changed, 7 insertions, 8 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 8448a8a..426de3a 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -66,9 +66,8 @@ int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorX                  g+=Ds[rank[i-j]];      }      VectorXd tmp = alpha.cwiseProduct(Ds)-gamma; -    VectorXd res = 0*s; -    for (int i=0;i<n;++i) -        res = res + D.getVec(i)*tmp[i]; +    VectorXd res = VectorXd::Zero(D.getSize()); +    cal_Dtw(D,tmp,res);      Hs = s + C*res;      return 0;  } @@ -192,9 +191,8 @@ int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd          cal_Dw(D,grad,Dd);          cal_alpha_beta(Dd,corr,D,rank,yt,alpha,beta);          VectorXd tmp = alpha.cwiseProduct(yt)-beta; -        VectorXd res = 0*grad; -        for (int i=0;i<n;++i) -            res = res + D.getVec(i)*tmp[i]; +        VectorXd res = VectorXd::Zero(D.getfSize()); +        cal_Dtw(D,tmp,res);          grad = grad + C*res;          g = grad.dot(step);          cal_Hs(D,rank,corr,alpha,step,Hs); @@ -238,7 +236,7 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){          VectorXd res = VectorXd::Zero(fsize);          cal_Dtw(Data,tmp,res);          grad = weight + C*res; -        step = grad*0; +        step = VectorXd::Zero(fsize);          // Solve          cg_solve(Data,rank,corr,alpha,grad,step);          // do line search @@ -272,7 +270,8 @@ int RSVMTN::train(RidList &D){  int RSVMTN::predict(RidList &D, vector<double> &res){      res.clear(); -    for (int i=0;i<D.getSize();++i) +    int n = D.getSize(); +    for (int i=0;i<n;++i)          res.push_back(D.getVec(i).dot(model.weight));      return 0;  };
\ No newline at end of file  | 
