diff options
-rw-r--r-- | CMakeLists.txt | 4 | ||||
-rw-r--r-- | tools/dataProvider.cpp | 7 | ||||
-rw-r--r-- | tools/dataProvider.h | 11 | ||||
-rw-r--r-- | train.cpp | 5 |
4 files changed, 25 insertions, 2 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e356fe..4ae8e51 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,8 +12,8 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) set(SOURCE_FILES model/ranksvm.cpp model/ranksvmtn.cpp model/rankaccu.cpp tools/fileDataProvider.cpp) -add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h) -add_executable(split split.cpp ${SOURCE_FILES}) +add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h tools/dataProvider.cpp) +add_executable(split split.cpp ${SOURCE_FILES} tools/dataProvider.cpp) add_dependencies(ranksvm split) TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) TARGET_LINK_LIBRARIES( split ${Boost_LIBRARIES})
\ No newline at end of file diff --git a/tools/dataProvider.cpp b/tools/dataProvider.cpp new file mode 100644 index 0000000..deb3b78 --- /dev/null +++ b/tools/dataProvider.cpp @@ -0,0 +1,7 @@ +// +// Created by joe on 5/19/15. +// + +#include "../tools/dataProvider.h" + +bool RidList::single = false;
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index bf47856..348d15c 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -61,6 +61,7 @@ private: std::vector<DataEntry*> uniq; std::vector<DataEntry*> other; public: + static bool single; void clear(){ uniq.clear(); other.clear(); @@ -98,6 +99,8 @@ public: } inline int getqSize() { + if (single) + return (int)other.size(); return (int)(uniq.size()+other.size()-1); } inline int getuSize() @@ -113,6 +116,8 @@ public: a=x/n; b=x%n; Eigen::VectorXd vec; + if (single) + return (uniq[a]->feature-other[b]->feature).cwiseAbs(); if (b<a) vec=uniq[a]->feature-uniq[b]->feature; else @@ -126,6 +131,12 @@ public: int a,b,n=getqSize(); a=x/n; b=x%n; + if (single) + { + if (std::fabs(other[b]->rank - a) < 1e-5) + return 1; + return -1; + } if (b<uniq.size()-1) return -1; else @@ -105,6 +105,7 @@ int main(int argc, char **argv) { ("predict,P", "use model for prediction") ("cmc,C", "enable cmc auditing") ("debug,d", "show debug messages") + ("single,s", "one from a pair") ("model,m", po::value<string>(), "set input model file") ("output,o", po::value<string>(), "set output model/prediction file") ("feature,i", po::value<string>(), "set input feature file") @@ -126,6 +127,10 @@ int main(int argc, char **argv) { el::Loggers::reconfigureLogger("default", defaultConf); mainFunc mainf; + if (vm.count("single")) + RidList::single=true; + else + RidList::single=false; if (vm.count("train")) { if (vm.count("c")) { C=vm["c"].as<double>(); |