Diffstat (limited to 'model')
-rw-r--r--  model/ranksvmtn.cpp | 47 ++++++++++++++++++++++++++++-------------------
1 file changed, 28 insertions(+), 19 deletions(-)
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 959ea7d..105c3fe 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -39,22 +39,30 @@ int cg_solve(const MatrixXd &A, const VectorXd &b, VectorXd &x)
 }
 
 // Calculate objfunc gradient & support vectors
-int objfunc_linear(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const double C,VectorXd &pred,VectorXd &grad, double &obj)
+int objfunc_linear(const VectorXd &w,const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const double C,VectorXd &pred,VectorXd &grad, double &obj)
 {
     for (int i=0;i<pred.rows();++i)
         pred(i)=pred(i)>0?pred(i):0;
     obj = (pred.cwiseProduct(pred)*C).sum()/2 + w.dot(w)/2;
-    grad = w - (((pred*C).transpose()*A)*D).transpose();
+    VectorXd pA = VectorXd::Zero(D.rows());
+    for (int i=0;i<A1.size();++i) {
+        pA(A1[i]) = pA(A1[i]) + pred(i);
+        pA(A2[i]) = pA(A2[i]) - pred(i);
+    }
+    grad = w - (pA.transpose()*D).transpose();
     return 0;
 }
 
 // line search using newton method
-int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const VectorXd &step,VectorXd &pred,double &t)
+int line_search(const VectorXd &w,const MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &step,VectorXd &pred,double &t)
 {
     double wd=w.dot(step),dd=step.dot(step);
+    VectorXd Dd = D*step;
+    VectorXd Xd = VectorXd::Zero(A1.size());
+    for (int i=0;i<A1.size();++i)
+        Xd(i) = Dd(A1[i])-Dd(A2[i]);
     double g,h;
     t = 0;
-    VectorXd Xd=A*(D*step);
     VectorXd pred2;
     while (1)
     {
@@ -69,8 +77,6 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
         g=g+1e-12;
         h=h+1e-12;
         t=t-g/h;
-        cout<<g<<":"<<h<<endl;
-        cin.get();
         if (g*g/h<1e-10)
             break;
     }
@@ -78,17 +84,20 @@ int line_search(const VectorXd &w,const MatrixXd &D,const MatrixXd &A,const Vect
     return 0;
 }
 
-int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
+int train_orig(int fsize, MatrixXd &D,vector<int> &A1,vector<int> &A2,VectorXd &weight){
     int iter = 0;
-    long n=A.rows();
-    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Relation size:" << A.rows();
+    long n=A1.size();
+    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Relation size:" << A1.size();
     VectorXd grad(fsize);
     VectorXd step(fsize);
     VectorXd pred(n);
     double obj,t;
-    pred=VectorXd::Ones(n) - (A*(D*weight));
+    VectorXd dw = D*weight;
+    pred=VectorXd::Zero(n);
+    for (int i=0;i<n;++i)
+        pred(i) = 1 - dw(A1[i])+dw(A2[i]);
     while (true)
     {
         iter+=1;
@@ -99,20 +108,21 @@ int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
         }
 
         // Generate support vector matrix sv & gradient
-        objfunc_linear(weight,D,A,C,pred,grad,obj);
+        objfunc_linear(weight,D,A1,A2,C,pred,grad,obj);
         step = grad*0;
         MatrixXd H = MatrixXd::Identity(grad.rows(),grad.rows());
         // Compute Hessian directly
         for (int i=0;i<n;++i)
             if (pred(i)>0) {
-                VectorXd v = A.row(i)*D;
+                VectorXd v = D.row(A1[i])-D.row(A2[i]);
                 H = H + C * (v * v.transpose());
             }
         // Solve
         //cout<<obj<<endl;
         cg_solve(H,grad,step);
         // do line search
-        line_search(weight,D,A,step,pred,t);
+
+        line_search(weight,D,A1,A2,step,pred,t);
         weight=weight+step*t;
         int sv=0;
         for (int i=0;i<n;++i)
@@ -127,7 +137,8 @@ int train_orig(int fsize, MatrixXd &D,MatrixXd &A,VectorXd &weight){
 }
 
 int RSVMTN::train(DataList &D){
-    MatrixXd Data(D.getSize(),D.getfSize()),A;
+    MatrixXd Data(D.getSize(),D.getfSize());
+    vector<int> A1,A2;
     int i,j;
     LOG(INFO)<<"Processing input";
     for (i=0;i<D.getSize();++i) {
@@ -148,7 +159,6 @@ int RSVMTN::train(DataList &D){
         }
         ++i;
     }
-    A.resize(cnt,D.getSize());
     cnt=i=j=0;
     while (i<D.getSize())
     {
@@ -157,20 +167,19 @@ int RSVMTN::train(DataList &D){
             int v1=j,v2;
             for (v1=j;(D.getData()[v1]->rank)>0;++v1)
                 for (v2=i;(D.getData()[v2]->rank)<0;--v2) {
-                    A(cnt,v1) = 1;
-                    A(cnt,v2) = -1;
+                    A1.push_back(v1);
+                    A2.push_back(v2);
                     ++cnt;
                 }
             j = i+1;
         }
         ++i;
     }
-    train_orig(fsize,Data,A,model.weight);
+    train_orig(fsize,Data,A1,A2,model.weight);
     return 0;
 };
 
 int RSVMTN::predict(DataList &D, vector<double> &res){
-    //TODO define A
     res.clear();
     for (int i=0;i<D.getSize();++i)
         res.push_back(((D.getData()[i])->feature).dot(model.weight));
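The core of the commit: the dense pair matrix A (one row per preference pair, +1 in the column of the preferred sample, -1 in the column of the other) is replaced by the two index vectors A1/A2, so every product with A or its transpose collapses into a gather or scatter over pair indices. Below is a minimal standalone sketch of that equivalence; the toy sizes, data, and main() harness are invented for illustration, only the A1/A2 convention comes from the patch.

// Standalone sketch (illustrative, not code from this repository).
// Pair k means "sample A1[k] ranks above sample A2[k]", so row k of the
// old dense A holds +1 at column A1[k] and -1 at column A2[k].
#include <Eigen/Dense>
#include <iostream>
#include <vector>

using Eigen::MatrixXd;
using Eigen::VectorXd;
using std::vector;

int main()
{
    const int nsamp = 5, fsize = 3;
    MatrixXd D = MatrixXd::Random(nsamp, fsize); // one feature row per sample
    VectorXd w = VectorXd::Random(fsize);

    vector<int> A1 = {0, 1, 2}; // preferred sample of each pair
    vector<int> A2 = {3, 4, 4}; // less-preferred sample of each pair
    const int npair = (int)A1.size();

    // Dense representation the commit removes: one +1/-1 row per pair.
    MatrixXd A = MatrixXd::Zero(npair, nsamp);
    for (int k = 0; k < npair; ++k) {
        A(k, A1[k]) = 1;
        A(k, A2[k]) = -1;
    }

    // Forward product A*(D*w): gather two entries of D*w per pair,
    // as the patched line_search and train_orig now do.
    VectorXd dw = D * w;
    VectorXd gather(npair);
    for (int k = 0; k < npair; ++k)
        gather(k) = dw(A1[k]) - dw(A2[k]);
    std::cout << (A * dw - gather).norm() << "\n"; // ~0

    // Transposed product (p^T A) D: scatter p into a per-sample vector,
    // exactly what the patched objfunc_linear does to build the gradient.
    VectorXd p = VectorXd::Random(npair);
    VectorXd pA = VectorXd::Zero(nsamp);
    for (int k = 0; k < npair; ++k) {
        pA(A1[k]) += p(k);
        pA(A2[k]) -= p(k);
    }
    VectorXd dense   = (p.transpose() * A * D).transpose();
    VectorXd scatter = (pA.transpose() * D).transpose();
    std::cout << (dense - scatter).norm() << "\n"; // ~0
    return 0;
}

Built against Eigen, both printed norms come out at machine precision: the loop forms compute the same quantities as the dense products, differing only in floating-point rounding.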

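Design note: with d training samples and p preference pairs, the old A was a p-by-d matrix that was almost entirely zeros, costing O(p*d) storage and O(p*d) work per product such as A*(D*w); A1/A2 keep the same information in 2p integers and turn those products into O(p) loops. The same observation yields the Hessian terms directly, since row i of A*D is just D.row(A1[i]) - D.row(A2[i]). The patch also drops a cout/cin.get() debugging pause from the line-search loop and the stale "//TODO define A" comment in predict(), which no longer applies once A is gone.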