diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-05-11 18:38:39 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-05-11 18:38:39 +0800 | 
| commit | c69b39a9f149cc6b5c7270d7d864fb677bc83b34 (patch) | |
| tree | bba8f4b244c5892265709b822751b9be56b4bf47 | |
| parent | 0f05a379b2c5df0b05c23fd91d697464bd250507 (diff) | |
| download | ranksvm-c69b39a9f149cc6b5c7270d7d864fb677bc83b34.tar.gz ranksvm-c69b39a9f149cc6b5c7270d7d864fb677bc83b34.tar.bz2 ranksvm-c69b39a9f149cc6b5c7270d7d864fb677bc83b34.zip | |
RidFile tested
| -rw-r--r-- | main.cpp | 2 | ||||
| -rw-r--r-- | model/rankaccu.cpp | 1 | ||||
| -rw-r--r-- | model/ranksvmtn.cpp | 27 | ||||
| -rw-r--r-- | tools/dataProvider.h | 11 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 6 | 
5 files changed, 36 insertions, 11 deletions
| @@ -25,7 +25,7 @@ int train(DataProvider &dp) {      DataList D;      LOG(INFO)<<"Training started"; -    dp.getAllData(D); +    dp.getAllDataSet(D);      LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";      rsvm->train(D);      vector<double> L; diff --git a/model/rankaccu.cpp b/model/rankaccu.cpp index 910e3e0..73b5d18 100644 --- a/model/rankaccu.cpp +++ b/model/rankaccu.cpp @@ -134,6 +134,7 @@ void rank_CMC(DataList &D,const std::vector<double> pred,CMC & cmc) {              for (int k=j;k<=i;++k)                  if (orig[pred_rank[k]]>0)                  { +                    cout<<pred_rank[k]<<" "<<pred[k+1]<<" "<< k <<" "<< j<<endl;                      cmc.addEntry(k-j);                      break; // account only for the first match;                  } diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index b82ce64..c2ca639 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -6,12 +6,13 @@  using namespace std;  using namespace Eigen; -const double C=1e-2; // Compensating & scaling +const double C=1e-5; // Compensating & scaling  // Main terminating criteria  const int maxiter = 10; // max iteration count -const double prec=1e-4; // precision +const double prec=1e-10; // precision  // conjugate gradient  const double cg_prec=1e-10; // precision +const int cg_maxiter = 30;  // line search  const double line_prec=1e-10; // precision  const double line_turb=1e-15; // purturbation @@ -64,6 +65,11 @@ int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,cons          x=x+p*alpha;          res=res-q*alpha;          ++step; +        if (step > cg_maxiter) +        { +            LOG(INFO) << "CG force terminated by maxiter"; +            break; +        }          r_2=r_1;      }      return 0; @@ -138,6 +144,7 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v      VectorXd grad;      VectorXd Hs;      vector<int> rank(D.rows()); +    int iter = 0;      for (int i=0;i<A1.size();++i)          Xd(i) = Dd(A1[i])-Dd(A2[i]); @@ -157,6 +164,12 @@ int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const v          t=t-g/h;          if (g*g/h<line_prec)              break; +        ++iter; +        if (iter > cg_maxiter) +        { +            LOG(INFO) << "line search force terminated by maxiter"; +            break; +        }      }      return 0;  } @@ -179,14 +192,12 @@ int train_orig(int fsize, MatrixXd &D,const vector<int> &A1,const vector<int> &A          iter+=1;          if (iter> maxiter)          { -            LOG(INFO)<< "Maxiter :"<<maxiter<<" reached"; +            LOG(INFO)<< "Maxiter reached";              break;          }          dw = D*weight; -          cal_alpha_beta(dw,corr,A1,A2,rank,yt,alpha,beta); -          // Generate support vector matrix sv & gradient          obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;//          grad = weight + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta)); @@ -213,8 +224,10 @@ int RSVMTN::train(DataList &D){      vector<DataEntry*> &dat = D.getData();      for (i=0;i<D.getSize();++i) {          corr(i)=(dat[i])->rank>0?0.5:-0.5; -        for (j = 0; j < D.getfSize(); ++j) -            Data(i, j) = dat[i]->feature(j); + +        for (j = 0; j < D.getfSize(); ++j){ +            Data(i, j) = dat[i]->feature(j);} +      }      i=j=0;      while (i<D.getSize()) diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 1d430e4..028980e 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -40,7 +40,11 @@ public:          DataEntry* dat = new DataEntry;          dat->rank = d->rank;          dat->qid = d->qid; -        dat->feature = d->feature; +        dat->feature.resize(d->feature.rows()); +        for (int i=0;i<d->feature.rows();++i) +        { +            dat->feature(i)=d->feature(i); +        }          return dat;      }      inline std::vector<DataEntry*>& getData(){ @@ -59,7 +63,7 @@ public:      DataProvider():eof(false){};      bool EOFile(){return eof;} -    void getAllData(DataList &out){\ +    void getAllDataSet(DataList &out){\          out.clear();          DataList buf;          while (!EOFile()) @@ -68,9 +72,12 @@ public:              // won't work as data are discarded with every call to getDataSet              // out.getData().insert(out.getData().end(),buf.getData().begin(),buf.getData().end());              for (int i=0;i<buf.getSize();++i) +            {                  out.addEntry(out.copyEntry(buf.getData()[i])); +            }              out.setfSize(buf.getfSize());          } +        buf.clear();      }      virtual int getDataSet(DataList &out) = 0;      virtual int open()=0; diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index c4f6a4a..8ebda20 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -82,9 +82,11 @@ public:              }              pos = 0;              qid = 1; +            read = true;          }          out.clear();          fsize = d.getfSize(); +        out.setfSize(fsize);          std::vector<DataEntry*> & dat = d.getData();          for (int i=0;i<d.getSize();++i)              if (i!=pos) @@ -94,6 +96,7 @@ public:                      e = new DataEntry;                      e->rank=1;                      dat[i]->qid=std::to_string(qid); +                    dat[i]->rank=qid;                  }                  else                  { @@ -103,12 +106,13 @@ public:                  e->feature.resize(d.getfSize());                  e->qid=std::to_string(qid);                  for (int j = 0; j < fsize; ++j) { -                    e->feature(i) = fabs(dat[i]->feature(j) -dat[pos]->feature(j)); +                    e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j));                  }                  out.addEntry(e);              }          dat[pos]->qid=std::to_string(qid);          ++qid; +        dat[pos]->rank=qid;          while (pos<dat.size() && dat[pos]->rank!=-1)              ++pos;          if (pos==d.getSize()) | 
