summaryrefslogtreecommitdiff
path: root/model/ranksvmtn.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'model/ranksvmtn.cpp')
-rw-r--r--model/ranksvmtn.cpp39
1 files changed, 26 insertions, 13 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 105c3fe..d6898d1 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -10,12 +10,33 @@ const int maxiter = 10;
const double prec=1e-4;
const double C=1;
-int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
+int computeHs(const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &pred,const double &C,const VectorXd s,VectorXd &Hs)
+{
+ Hs = VectorXd::Zero(s.rows());
+ VectorXd Ds=D*s;
+ long n = A1.size();
+ VectorXd Xs(n);
+ for (int i=0;i<n;++i)
+ Xs(i) = Ds(A1[i]) - Ds(A2[i]);
+ VectorXd ADXs = VectorXd::Zero(D.rows());
+ for (int i=0;i<n;++i)
+ if (pred(i)>0)
+ {
+ ADXs(A1[i]) = ADXs(A1[i]) + Xs(i);
+ ADXs(A2[i]) = ADXs(A2[i]) - Xs(i);
+ }
+ Hs = s + (C*(D.transpose()*ADXs));
+ return 0;
+}
+
+int cg_solve(const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &pred, const VectorXd &b,const double &C, VectorXd &x)
{
double alpha,beta,r_1,r_2;
int step=0;
VectorXd q;
- VectorXd res = b - A*x; // TODO Hessian product
+ VectorXd Hs;
+ computeHs(D,A1,A2,pred,C,x,Hs);
+ VectorXd res = b - Hs;
VectorXd p = res;
while (1)
{
@@ -27,8 +48,8 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
beta = r_1 / r_2;
p = res + p * beta;
}
-
- q = A*p; // TODO Hessian product
+ computeHs(D,A1,A2,pred,C,p,Hs);
+ q = Hs;
alpha = r_1/p.dot(q);
x=x+p*alpha;
res=res-q*alpha;
@@ -110,16 +131,8 @@ int train_orig(int fsize, MatrixXd &D,vector<int> &A1,vector<int> &A2,VectorXd &
// Generate support vector matrix sv & gradient
objfunc_linear(weight,D,A1,A2,C,pred,grad,obj);
step = grad*0;
- MatrixXd H = MatrixXd::Identity(grad.rows(),grad.rows());
- // Compute Hessian directly
- for (int i=0;i<n;++i)
- if (pred(i)>0) {
- VectorXd v = D.row(A1[i])-D.row(A2[i]);
- H = H + C * (v * v.transpose());
- }
// Solve
- //cout<<obj<<endl;
- cg_solve(H,grad,step);
+ cg_solve(D,A1,A2,pred,grad,C,step);
// do line search
line_search(weight,D,A1,A2,step,pred,t);