diff options
Diffstat (limited to 'model/ranksvmtn.cpp')
-rw-r--r-- | model/ranksvmtn.cpp | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 8448a8a..426de3a 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -66,9 +66,8 @@ int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorX g+=Ds[rank[i-j]]; } VectorXd tmp = alpha.cwiseProduct(Ds)-gamma; - VectorXd res = 0*s; - for (int i=0;i<n;++i) - res = res + D.getVec(i)*tmp[i]; + VectorXd res = VectorXd::Zero(D.getSize()); + cal_Dtw(D,tmp,res); Hs = s + C*res; return 0; } @@ -192,9 +191,8 @@ int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd cal_Dw(D,grad,Dd); cal_alpha_beta(Dd,corr,D,rank,yt,alpha,beta); VectorXd tmp = alpha.cwiseProduct(yt)-beta; - VectorXd res = 0*grad; - for (int i=0;i<n;++i) - res = res + D.getVec(i)*tmp[i]; + VectorXd res = VectorXd::Zero(D.getfSize()); + cal_Dtw(D,tmp,res); grad = grad + C*res; g = grad.dot(step); cal_Hs(D,rank,corr,alpha,step,Hs); @@ -238,7 +236,7 @@ int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){ VectorXd res = VectorXd::Zero(fsize); cal_Dtw(Data,tmp,res); grad = weight + C*res; - step = grad*0; + step = VectorXd::Zero(fsize); // Solve cg_solve(Data,rank,corr,alpha,grad,step); // do line search @@ -272,7 +270,8 @@ int RSVMTN::train(RidList &D){ int RSVMTN::predict(RidList &D, vector<double> &res){ res.clear(); - for (int i=0;i<D.getSize();++i) + int n = D.getSize(); + for (int i=0;i<n;++i) res.push_back(D.getVec(i).dot(model.weight)); return 0; };
\ No newline at end of file |