diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | model/rankmisc.h | 51 | ||||
-rw-r--r-- | model/ranksvm.cpp | 5 | ||||
-rw-r--r-- | tools/dataProvider.h | 22 |
4 files changed, 79 insertions, 1 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ae8e51..97d548e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) set(SOURCE_FILES model/ranksvm.cpp model/ranksvmtn.cpp model/rankaccu.cpp tools/fileDataProvider.cpp) -add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h tools/dataProvider.cpp) +add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h tools/dataProvider.cpp model/rankmisc.h) add_executable(split split.cpp ${SOURCE_FILES} tools/dataProvider.cpp) add_dependencies(ranksvm split) TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) diff --git a/model/rankmisc.h b/model/rankmisc.h new file mode 100644 index 0000000..2d2011d --- /dev/null +++ b/model/rankmisc.h @@ -0,0 +1,51 @@ +// +// Created by joe on 5/31/15. +// + +#ifndef RANKSVM_RANKMISC_H +#define RANKSVM_RANKMISC_H + +#include"ranksvm.h" + +class RSVML1:public RSVM +{ +public: + std::string getName() + { + return "L1"; + }; + virtual int train(RidList &D){LOG(FATAL)<< "NOT IMPLEMENTED"; return 0;}; + virtual int predict(RidList &D,std::vector<double> &res){ + res.clear(); + int n = D.getSize(); + Eigen::VectorXd one=Eigen::VectorXd::Ones(fsize); + for (int i=0;i<n;++i) + { + double r=D.getVecDot(i,one); + res.push_back(-r); + } + return 0; + }; +}; + +class RSVMBH:public RSVM +{ +public: + std::string getName() + { + return "BH"; + }; + virtual int train(RidList &D){LOG(FATAL)<< "NOT IMPLEMENTED"; return 0;}; + virtual int predict(RidList &D,std::vector<double> &res){ + res.clear(); + int n = D.getSize(); + for (int i=0;i<n;++i) + { + double r=D.getBha(i); + res.push_back(-r); + } + return 0; + }; +}; + +#endif //RANKSVM_RANKMISC_H diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index d246b2b..ed78fbe 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -1,5 +1,6 @@ #include"ranksvm.h" #include"ranksvmtn.h" +#include"rankmisc.h" #include"../tools/matrixIO.h" #include<iostream> #include<fstream> @@ -43,6 +44,10 @@ RSVM* RSVM::loadModel(const string fname){ if (type=="TN") rsvm = new RSVMTN(); + if (type=="L1") + rsvm = new RSVML1(); + if (type=="BH") + rsvm = new RSVMBH(); rsvm->fsize=fsize; SVMModel model; diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 47946c8..65a6b63 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -128,6 +128,28 @@ public: { return getuSize()*getqSize(); } + inline double getBha(int x){ + int a,b,q=getqSize(); + a=x/q; + b=x%q; + double res = 0; + Eigen::VectorXd *id,*oth; + if (single) + { + id = &(uniq[a]->feature); + oth = &(other[b]->feature); + } + else { + id = &(all[a]->feature); + if (b<a) + oth = &(all[b]->feature); + else + oth = &(all[b+1]->feature); + } + for (int i=0;i<n;++i) + res += sqrt((*id)[i] * (*oth)[i]); + return -log(res); + } inline Eigen::VectorXd getVec(int x){ int a,b,q=getqSize(); a=x/q; |