diff options
-rw-r--r-- | model/rankmisc.h | 20 | ||||
-rw-r--r-- | model/ranksvm.cpp | 2 | ||||
-rw-r--r-- | tools/dataProvider.h | 48 |
3 files changed, 65 insertions, 5 deletions
diff --git a/model/rankmisc.h b/model/rankmisc.h index 2d2011d..33d21a4 100644 --- a/model/rankmisc.h +++ b/model/rankmisc.h @@ -48,4 +48,24 @@ public: }; }; +class RSVMHE:public RSVM +{ +public: + std::string getName() + { + return "HE"; + }; + virtual int train(RidList &D){LOG(FATAL)<< "NOT IMPLEMENTED"; return 0;}; + virtual int predict(RidList &D,std::vector<double> &res){ + res.clear(); + int n = D.getSize(); + for (int i=0;i<n;++i) + { + double r=D.getHell(i); + res.push_back(-r); + } + return 0; + }; +}; + #endif //RANKSVM_RANKMISC_H diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index ed78fbe..20d65e3 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -48,6 +48,8 @@ RSVM* RSVM::loadModel(const string fname){ rsvm = new RSVML1(); if (type=="BH") rsvm = new RSVMBH(); + if (type=="HE") + rsvm = new RSVMHE(); rsvm->fsize=fsize; SVMModel model; diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 65a6b63..eed3079 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -147,8 +147,43 @@ public: oth = &(all[b+1]->feature); } for (int i=0;i<n;++i) - res += sqrt((*id)[i] * (*oth)[i]); - return -log(res); + { + double acc=0; + for (int j=0;j<16;++j,++i) + acc += sqrt((*id)[i] * (*oth)[i]); + res-=log(acc+1e-30); + } + return res; + } + inline double getHell(int x){ + int a,b,q=getqSize(); + a=x/q; + b=x%q; + double res = 0; + Eigen::VectorXd *id,*oth; + if (single) + { + id = &(uniq[a]->feature); + oth = &(other[b]->feature); + } + else { + id = &(all[a]->feature); + if (b<a) + oth = &(all[b]->feature); + else + oth = &(all[b+1]->feature); + } + for (int i=0;i<n;++i) + { + double acc=0; + for (int j=0;j<16;++j,++i) + acc += sqrt((*id)[i] * (*oth)[i]); + res+=sqrt(1-acc); + } + return res; + } + inline double cal(Eigen::VectorXd *id,Eigen::VectorXd *oth,int i) { + return fabs((*id)[i] - (*oth)[i]); } inline Eigen::VectorXd getVec(int x){ int a,b,q=getqSize(); @@ -167,7 +202,10 @@ public: else oth = &(all[b+1]->feature); } - return (*id-*oth).cwiseAbs(); + Eigen::VectorXd res(n); + for (int i=0;i<n;++i) + res(i)=cal(id,oth,i); + return res; }; inline double getVecDot(int x,const Eigen::VectorXd &w) { @@ -189,7 +227,7 @@ public: oth = &(all[b+1]->feature); } for (int i=0;i<n;++i) - res += fabs((*id)[i] - (*oth)[i])*w[i]; + res += cal(id,oth,i)*w[i]; return res; } inline void addVecw(int x,double w,Eigen::VectorXd &X) @@ -211,7 +249,7 @@ public: oth = &(all[b+1]->feature); } for (int i=0;i<n;++i) - X[i] += fabs((*id)[i] - (*oth)[i])*w; + X[i] += cal(id,oth,i)*w; } inline double getL(int x){ int a,b,q=getqSize(); |