 model/ranksvm.h            |   2
 model/ranksvmtn.cpp        | 126
 model/ranksvmtn.h          |   2
 tools/dataProvider.h       |  81
 tools/fileDataProvider.cpp |  11
 tools/fileDataProvider.h   |   2
 train.cpp                  |   4
 7 files changed, 144 insertions, 84 deletions
diff --git a/model/ranksvm.h b/model/ranksvm.h
index aa5e1ca..58cd7f0 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -25,7 +25,7 @@ protected:
     SVMModel model;
     int fsize;
 public:
-    virtual int train(DataList &D)=0;
+    virtual int train(RidList &D)=0;
     virtual int predict(DataList &D,std::vector<double> &res)=0;
     // TODO Not sure how to construct this
     //  Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
index 01d9851..d3ef3af 100644
--- a/model/ranksvmtn.cpp
+++ b/model/ranksvmtn.cpp
@@ -17,37 +17,44 @@ const int cg_maxiter = 30;
 const double line_prec=1e-10; // precision
 const double line_turb=1e-15; // purturbation
 
-int cal_Hs(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const vector<int> &A1,const vector<int> &A2,const VectorXd s,VectorXd &Hs)
+int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const VectorXd s,VectorXd &Hs)
 {
     Hs = VectorXd::Zero(s.rows());
-    VectorXd Ds=D*s;
-    VectorXd gamma(D.rows());
-    for (int i=0;i<A1.size();++i)
+    VectorXd Ds(D.getSize());
+    for (int i=0;i<D.getSize();++i)
+        Ds(i) = D.getVec(i).dot(s);
+    VectorXd gamma(D.getSize());
+    for (int i=0;i<D.getSize();)
     {
         double g=0;
-        for (int j = A1[i];j<=A2[i];++j)
-            if (corr[rank[j]]<0)
-                gamma[rank[j]]=g;
+        for (int j = D.getqSize()-1;j>=0;--j)
+            if (corr[rank[i+j]]>0)
+                gamma[rank[i+j]]=g;
             else
-                g+=Ds[rank[j]];
+                g+=Ds[rank[i+j]];
         g=0;
-        for (int j = A2[i];j>=A1[i];--j)
-            if (corr[rank[j]]>0)
-                gamma[rank[j]]=g;
+        i+=D.getqSize();
+        for (int j = D.getqSize();j>0;--j)
+            if (corr[rank[i-j]]<0)
+                gamma[rank[i-j]]=g;
             else
-                g+=Ds[rank[j]];
+                g+=Ds[rank[i-j]];
     }
-    Hs = s + C*(D.transpose()*(alpha.cwiseProduct(Ds) - gamma));
+    VectorXd tmp = alpha.cwiseProduct(Ds)-gamma;
+    VectorXd res = 0*s;
+    for (int i=0;i<D.getSize();++i)
+        res = res + D.getVec(i)*tmp[i];
+    Hs = s + C*res;
     return 0;
 }
 
-int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const vector<int> &A1,const vector<int> &A2,const VectorXd &b, VectorXd &x)
+int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const VectorXd &b, VectorXd &x)
 {
     double alpha,beta,r_1,r_2;
     int iter=0;
     VectorXd q;
     VectorXd Hs;
-    cal_Hs(D,rank,corr,alph,A1,A2,x,Hs);
+    cal_Hs(D,rank,corr,alph,x,Hs);
     VectorXd res = b - Hs;
     VectorXd p = res;
     while (1)
@@ -65,7 +72,7 @@ int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,cons
             beta = r_1 / r_2;
             p = res + p * beta;
         }
-        cal_Hs(D,rank,corr,alph,A1,A2,p,q);
+        cal_Hs(D,rank,corr,alph,p,q);
         alpha = r_1/p.dot(q);
         x=x+p*alpha;
         res=res-q*alpha;
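
With RidList the pairwise data matrix is never materialized: cal_Hs now rebuilds D*s and D^T*v one virtual row at a time through D.getVec(i), so cg_solve touches the Hessian only through Hessian-vector products. For readers unfamiliar with the pattern, here is a minimal matrix-free conjugate gradient in the same spirit; it is a sketch, and the names hess_vec, tol and maxiter are illustrative rather than part of this codebase.

#include <Eigen/Dense>
#include <functional>
#include <iostream>

using Eigen::MatrixXd;
using Eigen::VectorXd;

// Solve H*x = b given only a routine returning H*s for an arbitrary s.
VectorXd cg(const std::function<VectorXd(const VectorXd&)> &hess_vec,
            const VectorXd &b, double tol = 1e-10, int maxiter = 30)
{
    VectorXd x = VectorXd::Zero(b.rows());
    VectorXd r = b - hess_vec(x);       // residual
    VectorXd p = r;                     // search direction
    double rs = r.dot(r);
    for (int it = 0; it < maxiter && rs > tol; ++it)
    {
        VectorXd q = hess_vec(p);
        double alpha = rs / p.dot(q);   // exact step length along p
        x += alpha * p;
        r -= alpha * q;
        double rs_new = r.dot(r);
        p = r + (rs_new / rs) * p;      // conjugate direction update
        rs = rs_new;
    }
    return x;
}

int main()
{
    MatrixXd H(2,2); H << 4, 1, 1, 3;   // small SPD test matrix
    VectorXd b(2);   b << 1, 2;
    VectorXd x = cg([&](const VectorXd &s){ return VectorXd(H*s); }, b);
    std::cout << x.transpose() << std::endl;   // ~ (0.0909, 0.6364)
    return 0;
}
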
@@ -98,18 +105,19 @@ void ranksort(int l,int r,vector<int> &rank,VectorXd &ref)
         ranksort(i,r,rank,ref);
 }
 
-int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1,const vector<int> &A2,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta)
+int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta)
 {
     long n = dw.rows();
     yt = dw - corr;
     alpha=VectorXd::Zero(n);
     beta=VectorXd::Zero(n);
     for (int i=0;i<n;++i) rank[i]=i;
-    for (int i=0;i<A1.size();++i)
+    for (int i=0;i<D.getSize();i+=D.getqSize())
     {
-        ranksort(A1[i],A2[i],rank,yt);
+        int ed=i+D.getqSize()-1;
+        ranksort(i,ed,rank,yt);
         double a=0,b=0;
-        for (int j=A1[i];j<=A2[i];++j)
+        for (int j=i;j<=ed;++j)
             if (corr[rank[j]]<0)
             {
                 alpha[rank[j]]=a;
@@ -121,7 +129,7 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1
                 b+=yt[rank[j]];
             }
         a=b=0;
-        for (int j=A2[i];j>=A1[i];--j)
+        for (int j=ed;j>=i;--j)
             if (corr[rank[j]]>0)
             {
                 alpha[rank[j]]=a;
@@ -136,28 +144,33 @@ int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1
 }
 
 // line search using newton method
-int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const vector<int> &A1,const vector<int> &A2,const VectorXd &step,double &t)
+int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd &step,double &t)
 {
-    VectorXd Dd = D*step;
-    VectorXd Xd = VectorXd::Zero(A1.size());
+    VectorXd Dd(D.getSize());
+    for (int i=0;i<D.getSize();++i)
+        Dd(i) = D.getVec(i).dot(step);
     VectorXd alpha,beta,yt;
     VectorXd grad;
     VectorXd Hs;
-    vector<int> rank(D.rows());
+    int n=D.getSize();
+    vector<int> rank(D.getSize());
     int iter = 0;
-    for (int i=0;i<A1.size();++i)
-        Xd(i) = Dd(A1[i])-Dd(A2[i]);
     double g,h;
     t = 0;
     while (1)
     {
         grad=w+t*step;
-        Dd = D*(w + t*step);
-        cal_alpha_beta(Dd,corr,A1,A2,rank,yt,alpha,beta);
-        grad = grad + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta));
+        for (int i=0;i<D.getSize();++i)
+            Dd(i) = D.getVec(i).dot(grad);
+        cal_alpha_beta(Dd,corr,D,rank,yt,alpha,beta);
+        VectorXd tmp = alpha.cwiseProduct(yt)-beta;
+        VectorXd res = 0*grad;
+        for (int i=0;i<n;++i)
+            res = res + D.getVec(i)*tmp[i];
+        grad = grad + C*res;
         g = grad.dot(step);
-        cal_Hs(D,rank,corr,alpha,A1,A2,step,Hs);
+        cal_Hs(D,rank,corr,alpha,step,Hs);
         h = Hs.dot(step);
         g=g+line_turb;
         h = h+line_turb;
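
line_search above is plain one-dimensional Newton: g and h are the first and second directional derivatives of the objective along step, and line_turb perturbs both so the division survives a vanishing second derivative. Stripped of the RankSVM bookkeeping, the iteration reduces to the following sketch; newton_1d and its stopping rule are my reading of the loop, not project code.

#include <cstdio>
#include <functional>

// 1-D Newton: minimize phi(t) given its derivatives g(t) and h(t).
double newton_1d(const std::function<double(double)> &g,
                 const std::function<double(double)> &h,
                 double prec = 1e-10, double turb = 1e-15)
{
    double t = 0;
    for (int it = 0; it < 50; ++it)     // hard cap, like the project's iter guards
    {
        double gt = g(t) + turb;        // perturbed first derivative
        double ht = h(t) + turb;        // perturbed second derivative
        t -= gt / ht;                   // Newton update
        if (gt * gt / ht < prec)        // stop once the Newton decrement is tiny
            break;
    }
    return t;
}

int main()
{
    // Minimize phi(t) = (t-3)^2: g = 2(t-3), h = 2; minimum at t = 3.
    double t = newton_1d([](double t){ return 2*(t-3); },
                         [](double)  { return 2.0; });
    std::printf("%f\n", t);   // 3.000000
    return 0;
}
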
@@ -174,17 +187,17 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v
     return 0;
 }
 
-int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A2,const VectorXd &corr,VectorXd &weight){
+int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
     int iter = 0;
-    long n=D.rows();
-    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Query size:" << A1.size();
+    long n=Data.getSize();
+    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Query size:" << Data.getuSize();
     VectorXd grad(fsize);
     VectorXd step(fsize);
     vector<int> rank(n);
     double obj,t;
-    VectorXd dw = D*weight;
+    VectorXd dw(n);
     VectorXd yt;
     VectorXd alpha,beta;
     while (true)
@@ -196,16 +209,21 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
             break;
         }
 
-        dw = D*weight;
-        cal_alpha_beta(dw,corr,A1,A2,rank,yt,alpha,beta);
+        for (int i=0;i<n;++i)
+            dw(i) = Data.getVec(i).dot(weight);
+        cal_alpha_beta(dw,corr,Data,rank,yt,alpha,beta);
         // Generate support vector matrix sv & gradient
-        obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;//
-        grad = weight + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta));
+        obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;
+        VectorXd tmp = alpha.cwiseProduct(yt)-beta;
+        VectorXd res = 0*weight;
+        for (int i=0;i<n;++i)
+            res = res + Data.getVec(i)*tmp[i];
+        grad = weight + C*res;
         step = grad*0;
         // Solve
-        cg_solve(D,rank,corr,alpha,A1,A2,grad,step);
+        cg_solve(Data,rank,corr,alpha,grad,step);
         // do line search
-        line_search(weight,D,corr,A1,A2,step,t);
+        line_search(weight,Data,corr,step,t);
         weight=weight+step*t;
         // When dec is small enough
         LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Newton decr:"<<step.dot(grad)/2 << " linesearch: "<< -t ;
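
For reference, the quantities train_orig maintains, written out. This is a reading of the code above: yt = D*w - corr, with alpha, beta, gamma the per-sample counts and partial sums that cal_alpha_beta and cal_Hs accumulate by sorting within each query block.

\begin{aligned}
\text{obj} &= \tfrac{1}{2}\Bigl(w^\top w + C\bigl(\alpha^\top (y_t \circ y_t) - \beta^\top y_t\bigr)\Bigr),\\
\text{grad} &= w + C\,D^\top(\alpha \circ y_t - \beta),\\
H s &= s + C\,D^\top\bigl(\alpha \circ (D s) - \gamma\bigr),
\qquad y_t = D w - \text{corr}.
\end{aligned}

The refactor changes none of these formulas; every product with D or D^T is simply replaced by a loop over D.getVec(i), trading the n-by-fsize matrix for on-demand row reconstruction.
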
@@ -215,32 +233,14 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A
     return 0;
 }
 
-int RSVMTN::train(DataList &D){
-    MatrixXd Data(D.getSize(),D.getfSize());
+int RSVMTN::train(RidList &D){
     VectorXd corr(D.getSize());
     vector<int> A1,A2;
     int i,j;
     LOG(INFO)<<"Processing input";
-    vector<DataEntry*> &dat = D.getData();
-    for (i=0;i<D.getSize();++i) {
-        corr(i)=(dat[i])->rank>0?0.5:-0.5;
-
-        for (j = 0; j < D.getfSize(); ++j){
-            Data(i, j) = dat[i]->feature(j);}
-
-    }
-    i=j=0;
-    while (i<D.getSize())
-    {
-        if ((i+1 == D.getSize())|| dat[i]->qid!=dat[i+1]->qid)
-        {
-            A1.push_back(j);
-            A2.push_back(i);
-            j = i+1;
-        }
-        ++i;
-    }
-    train_orig(fsize,Data,A1,A2,corr,model.weight);
+    for (i=0;i<D.getSize();++i)
+        corr(i)=D.getL(i)>0?0.5:-0.5;
+    train_orig(fsize,D,corr,model.weight);
     return 0;
 };
 
diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h
index 4074781..c98e581 100644
--- a/model/ranksvmtn.h
+++ b/model/ranksvmtn.h
@@ -12,7 +12,7 @@ public:
     {
         return "TN";
     };
-    virtual int train(DataList &D);
+    virtual int train(RidList &D);
     virtual int predict(DataList &D,std::vector<double> &res);
 };
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index 028980e..a3f3d34 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -55,6 +55,70 @@ public:
     }
 };
 
+class RidList{
+private:
+    int n;
+    std::vector<DataEntry*> uniq;
+    std::vector<DataEntry*> other;
+public:
+    void clear(){
+        uniq.clear();
+        other.clear();
+    }
+    void setfSize(int fsize){n=fsize;}
+    int getfSize(){return n;}
+    void addEntry(DataEntry* d){
+        int ext=false;
+        for (int i=0;i<uniq.size();++i)
+            if (uniq[i]->qid==d->qid)
+            {
+                ext = true;
+                d->rank = i;
+            }
+        if (ext)
+            other.push_back(d);
+        else
+            uniq.push_back(d);
+    }
+    int getqSize()
+    {
+        return (int)(uniq.size()+other.size()-1);
+    }
+    int getuSize()
+    {
+        return (int)uniq.size();
+    }
+    int getSize()
+    {
+        return getuSize()*getqSize();
+    }
+    Eigen::VectorXd getVec(int x){
+        int a,b,n=getqSize();
+        a=x/n;
+        b=x%n;
+        Eigen::VectorXd vec = uniq[a]->feature;
+        if (b<a)
+            vec=vec-uniq[b]->feature;
+        else
+        if (b<uniq.size()-1)
+            vec=vec-uniq[b+1]->feature;
+        else
+            vec=vec-other[b-uniq.size()+1]->feature;
+        return vec.cwiseAbs();
+    };
+    double getL(int x){
+        int a,b,n=(int)(uniq.size()+other.size()-1);
+        a=x/n;
+        b=x%n;
+        if (b<uniq.size()-1)
+            return -1;
+        else
+            if (std::fabs(other[b-uniq.size()+1]->rank - a) < 1e-5)
+                return 1;
+        return -1;
+    };
+};
+
 class DataProvider  //Virtual base class for data input
 {
 protected:
@@ -63,22 +127,7 @@ public:
     DataProvider():eof(false){};
     bool EOFile(){return eof;}
 
-    void getAllDataSet(DataList &out){\
-        out.clear();
-        DataList buf;
-        while (!EOFile())
-        {
-            getDataSet(buf);
-            // won't work as data are discarded with every call to getDataSet
-            // out.getData().insert(out.getData().end(),buf.getData().begin(),buf.getData().end());
-            for (int i=0;i<buf.getSize();++i)
-            {
-                out.addEntry(out.copyEntry(buf.getData()[i]));
-            }
-            out.setfSize(buf.getfSize());
-        }
-        buf.clear();
-    }
+    virtual void getAllDataSet(RidList &out) = 0;
     virtual int getDataSet(DataList &out) = 0;
     virtual int open()=0;
     virtual int close()=0;
diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp
index e9b7f3d..1ff0279 100644
--- a/tools/fileDataProvider.cpp
+++ b/tools/fileDataProvider.cpp
@@ -170,4 +170,15 @@ void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b)
             b.push_back(tmp[i]);
     scrambler(a);
     scrambler(b);
+}
+
+void RidFileDP::getAllDataSet(RidList &out){
+    DataEntry *e;
+    if (!read)
+        readEntries();
+    out.clear();
+    std::vector<DataEntry*> &dat = d.getData();
+    for (int i=0;i<dat.size();++i)
+        out.addEntry(dat[i]);
+    out.setfSize(d.getfSize());
 }
\ No newline at end of file
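
The subtlest part of the change is RidList's implicit indexing: each unique entry uniq[a] is paired with the qSize = uniq.size() + other.size() - 1 entries other than itself, and getVec(x) decodes the flat index x into such a pair, returning the elementwise absolute difference of the two feature vectors. A standalone demo of just the index arithmetic, with sizes picked arbitrarily for illustration:

#include <cstdio>

// Replicates the a/b decoding in RidList::getVec: row x pairs uniq[a]
// with one of its qSize partners, skipping uniq[a] itself.
int main()
{
    int uSize = 3, oSize = 4;          // 3 unique entries, 4 "other" entries
    int qSize = uSize + oSize - 1;     // 6 partners per unique entry
    for (int x = 0; x < uSize * qSize; ++x)
    {
        int a = x / qSize, b = x % qSize;
        if (b < a)
            std::printf("row %2d: uniq[%d] vs uniq[%d]\n", x, a, b);
        else if (b < uSize - 1)
            std::printf("row %2d: uniq[%d] vs uniq[%d]\n", x, a, b + 1);
        else
            std::printf("row %2d: uniq[%d] vs other[%d]\n", x, a, b - uSize + 1);
    }
    return 0;
}

getL then labels a row +1 exactly when the partner is an "other" entry recorded as belonging to query a (addEntry stores the query's index in d->rank); every other pair gets -1.
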
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index 7bea92d..567c8e2 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -17,6 +17,7 @@ private:
 public:
     FileDP(std::string fn=""):fname(fn){};
     virtual int getDataSet(DataList &out);
+    virtual void getAllDataSet(RidList &out){ LOG(FATAL)<<"getAllDataSet for normal FileDP not implemented";};
     virtual int open(){fin.open(fname); eof=false;return 0;};
     virtual int close(){fin.close();return 0;};
 };
@@ -37,6 +38,7 @@ public:
     void readEntries();
     int getfSize() { if(!read) readEntries(); return d.getfSize();};
     int getpSize();
+    virtual void getAllDataSet(RidList &out);
     virtual int getDataSet(DataList &out);
     virtual int open(){fin.open(fname); eof=false;return 0;};
     virtual int close(){fin.close(); d.clear();return 0;};
diff --git a/train.cpp b/train.cpp
--- a/train.cpp
+++ b/train.cpp
@@ -22,14 +22,12 @@ int train(DataProvider &dp) {
     rsvm = RSVM::loadModel(vm["model"].as<string>());
     dp.open();
 
-    DataList D;
+    RidList D;
     LOG(INFO)<<"Training started";
     dp.getAllDataSet(D);
     LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
     rsvm->train(D);
 
-    vector<double> L;
-    rsvm->predict(D,L);
 
     LOG(INFO)<<"Training finished,saving model";
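
Taken together, the new call sequence in train.cpp is: open the provider, pull everything into a RidList, and train on it directly. A sketch of that flow as a standalone helper; run_training is hypothetical, the include paths are assumed from the repository layout, and error handling is omitted.

#include "tools/fileDataProvider.h"
#include "model/ranksvm.h"

// Hypothetical wrapper showing the new data flow end to end; only a
// RidFileDP actually implements getAllDataSet(RidList&), FileDP aborts.
void run_training(DataProvider &dp, RSVM *rsvm)
{
    dp.open();
    RidList D;             // implicit pairwise view; rows are generated on demand
    dp.getAllDataSet(D);   // RidFileDP fills uniq/other and the feature size
    rsvm->train(D);        // truncated-Newton training directly on the RidList
    dp.close();
}
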
