Diffstat (limited to 'model/ranksvmtn.cpp')
 model/ranksvmtn.cpp | 47 ++++++++++++++++++++++++++++-------------------
 1 file changed, 28 insertions(+), 19 deletions(-)
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 959ea7d..105c3fe 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -39,22 +39,30 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
}
// Calculate objfunc gradient & support vectors
-int objfunc_linear(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const double C,VectorXd &pred,VectorXd &grad, double &obj)
+int objfunc_linear(const VectorXd &w,const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const double C,VectorXd &pred,VectorXd &grad, double &obj)
{
for (int i=0;i<pred.rows();++i)
pred(i)=pred(i)>0?pred(i):0;
obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2;
- grad = w - (((pred*C).transpose()*A)*D).transpose();
+ VectorXd pA = VectorXd::Zero(D.rows());
+ for (int i=0;i<A1.size();++i) {
+ pA(A1[i]) = pA(A1[i]) + pred(i);
+ pA(A2[i]) = pA(A2[i]) - pred(i);
+ }
+ grad = w - (pA.transpose()*D).transpose();
return 0;
}
// line search using newton method
-int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const VectorXd &step,VectorXd &pred,double &t)
+int line_search(const VectorXd &w,const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &step,VectorXd &pred,double &t)
{
double wd=w.dot(step),dd=step.dot(step);
+ VectorXd Dd = D*step;
+ VectorXd Xd = VectorXd::Zero(A1.size());
+ for (int i=0;i<A1.size();++i)
+ Xd(i) = Dd(A1[i])-Dd(A2[i]);
double g,h;
t = 0;
- VectorXd Xd=A*(D*step);
VectorXd pred2;
while (1)
{
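The refactor above drops the dense pair matrix A: each of its rows carried a +1 in column A1[i] and a -1 in column A2[i], so products with A and A^T reduce to index gathers and scatters. A minimal standalone sketch of both primitives (hypothetical helper names; Eigen types as in the file):

    #include <Eigen/Dense>
    #include <vector>
    using Eigen::VectorXd;
    using std::vector;

    // out = A^T * p for the implicit pair matrix: scatter +/-p(i) to the
    // two sample indices of pair i (what the pA loop above does).
    VectorXd pair_scatter(const vector<int> &A1, const vector<int> &A2,
                          const VectorXd &p, long rows)
    {
        VectorXd out = VectorXd::Zero(rows);
        for (size_t i = 0; i < A1.size(); ++i) {
            out(A1[i]) += p(i);
            out(A2[i]) -= p(i);
        }
        return out;
    }

    // out = A * v: gather the difference of the two sample entries of each
    // pair (what the Xd loop above does with v = D*step).
    VectorXd pair_gather(const vector<int> &A1, const vector<int> &A2,
                         const VectorXd &v)
    {
        VectorXd out(A1.size());
        for (size_t i = 0; i < A1.size(); ++i)
            out(i) = v(A1[i]) - v(A2[i]);
        return out;
    }

One detail worth checking: the removed gradient line scaled pred by C before multiplying by A, while the added scatter accumulates pred unscaled, so unless the factor is applied elsewhere the data term of grad appears to have lost its factor of C in this change.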
@@ -69,8 +77,6 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
g=g+1e-12;
h=h+1e-12;
t=t-g/h;
- cout<<g<<":"<<h<<endl;
- cin.get();
if (g*g/h<1e-10)
break;
}
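The loop above is a guarded one-dimensional Newton iteration on the step length t: g and h are the first and second derivatives of the objective along the search direction (computed from pred2 and Xd in the lines not shown in this hunk), the 1e-12 jitter keeps h strictly positive, and the iteration stops once the expected decrease g*g/h falls below 1e-10. The removed cout/cin.get() pair was leftover interactive debugging. A generic sketch of the same pattern (hypothetical names, derivative callbacks assumed):

    #include <functional>

    // Minimize phi(t) by Newton's method given its first two derivatives.
    double newton_1d(const std::function<double(double)> &dphi,
                     const std::function<double(double)> &ddphi,
                     double tol = 1e-10)
    {
        double t = 0;
        while (true) {
            double g = dphi(t) + 1e-12;   // jitter avoids division by zero
            double h = ddphi(t) + 1e-12;  // curvature along the direction
            t -= g / h;                   // Newton update
            if (g * g / h < tol)          // expected decrease is tiny: done
                break;
        }
        return t;
    }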
@@ -78,17 +84,20 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
return 0;
}
-int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
+int train_orig(int fsize, MatrixXd &D,vector<int> &A1,vector<int> &A2,VectorXd &weight){
int iter = 0;
- long n=A.rows();
- LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Relation size:" << A.rows();
+ long n=A1.size();
+ LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Relation size:" << A1.size();
VectorXd grad(fsize);
VectorXd step(fsize);
VectorXd pred(n);
double obj,t;
- pred=VectorXd::Ones(n) - (A*(D*weight));
+ VectorXd dw = D*weight;
+ pred=VectorXd::Zero(n);
+ for (int i=0;i<n;++i)
+ pred(i) = 1 - dw(A1[i])+dw(A2[i]);
while (true)
{
iter+=1;
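With the pair_gather helper sketched earlier, the new initialization is the matrix-free form of the removed line: pred = VectorXd::Ones(n) - pair_gather(A1, A2, D*weight), i.e. one margin residual 1 - w.(x_{A1[i]} - x_{A2[i]}) per preference pair.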
@@ -99,20 +108,21 @@ int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
}
// Generate support vector matrix sv & gradient
- objfunc_linear(weight,D,A,C,pred,grad,obj);
+ objfunc_linear(weight,D,A1,A2,C,pred,grad,obj);
step = grad*0;
MatrixXd H = MatrixXd::Identity(grad.rows(),grad.rows());
// Compute Hessian directly
for (int i=0;i<n;++i)
if (pred(i)>0) {
- VectorXd v = A.row(i)*D;
+ VectorXd v = D.row(A1[i])-D.row(A2[i]);
H = H + C * (v * v.transpose());
}
// Solve
//cout<<obj<<endl;
cg_solve(H,grad,step);
// do line search
- line_search(weight,D,A,step,pred,t);
+
+ line_search(weight,D,A1,A2,step,pred,t);
weight=weight+step*t;
int sv=0;
for (int i=0;i<n;++i)
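Forming H explicitly still costs O(sv * fsize^2) time and O(fsize^2) memory. Since cg_solve only ever needs products H*s, a matrix-free variant is possible under the same pair-index representation; a hypothetical sketch (not what this commit does), using H = I + C * (AD)^T diag(sv) (AD):

    #include <Eigen/Dense>
    #include <vector>
    using Eigen::MatrixXd;
    using Eigen::VectorXd;
    using std::vector;

    // Hessian-vector product H*s with one gather and one scatter over the
    // active pairs (pred(i) > 0), never materializing H.
    VectorXd hess_vec(const MatrixXd &D, const vector<int> &A1,
                      const vector<int> &A2, const VectorXd &pred,
                      double C, const VectorXd &s)
    {
        VectorXd Ds = D * s;                        // D*s once per product
        VectorXd acc = VectorXd::Zero(D.rows());
        for (size_t i = 0; i < A1.size(); ++i)
            if (pred(i) > 0) {                      // support-vector pairs only
                double g = Ds(A1[i]) - Ds(A2[i]);   // (A*D*s)(i)
                acc(A1[i]) += g;                    // scatter of A^T * (...)
                acc(A2[i]) -= g;
            }
        return s + C * (D.transpose() * acc);       // s + C*(AD)^T diag(sv)(AD)s
    }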
@@ -127,7 +137,8 @@ int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
}
int RSVMTN::train(DataList &D){
- MatrixXd Data(D.getSize(),D.getfSize()),A;
+ MatrixXd Data(D.getSize(),D.getfSize());
+ vector<int> A1,A2;
int i,j;
LOG(INFO)<<"Processing input";
for (i=0;i<D.getSize();++i) {
@@ -148,7 +159,6 @@ int RSVMTN::train(DataList &D){
}
++i;
}
- A.resize(cnt,D.getSize());
cnt=i=j=0;
while (i<D.getSize())
{
@@ -157,20 +167,19 @@ int RSVMTN::train(DataList &D){
int v1=j,v2;
for (v1=j;(D.getData()[v1]->rank)>0;++v1)
for (v2=i;(D.getData()[v2]->rank)<0;--v2) {
- A(cnt,v1) = 1;
- A(cnt,v2) = -1;
+ A1.push_back(v1);
+ A2.push_back(v2);
++cnt;
}
j = i+1;
}
++i;
}
- train_orig(fsize,Data,A,model.weight);
+ train_orig(fsize,Data,A1,A2,model.weight);
return 0;
};
int RSVMTN::predict(DataList &D, vector<double> &res){
- //TODO define A
res.clear();
for (int i=0;i<D.getSize();++i)
res.push_back(((D.getData()[i])->feature).dot(model.weight));
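The pair-generation loops in the last hunk record each (relevant, non-relevant) pair inside a query block as an index pair instead of a dense +/-1 row, shrinking memory from O(pairs x samples) to O(pairs). A simplified single-query sketch of the same idea (hypothetical function, plain rank labels assumed):

    #include <vector>
    using std::vector;

    // rank[k] > 0 marks a relevant item, rank[k] < 0 a non-relevant one.
    void build_pairs(const vector<int> &rank, vector<int> &A1, vector<int> &A2)
    {
        for (size_t p = 0; p < rank.size(); ++p)
            for (size_t q = 0; q < rank.size(); ++q)
                if (rank[p] > 0 && rank[q] < 0) {
                    A1.push_back(int(p));   // preferred item of the pair
                    A2.push_back(int(q));   // less-preferred item of the pair
                }
    }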