summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt4
-rw-r--r--tools/dataProvider.cpp7
-rw-r--r--tools/dataProvider.h11
-rw-r--r--train.cpp5
4 files changed, 25 insertions, 2 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0e356fe..4ae8e51 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,8 +12,8 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )
INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR})
set(SOURCE_FILES model/ranksvm.cpp model/ranksvmtn.cpp model/rankaccu.cpp tools/fileDataProvider.cpp)
-add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h)
-add_executable(split split.cpp ${SOURCE_FILES})
+add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h tools/dataProvider.cpp)
+add_executable(split split.cpp ${SOURCE_FILES} tools/dataProvider.cpp)
add_dependencies(ranksvm split)
TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
TARGET_LINK_LIBRARIES( split ${Boost_LIBRARIES}) \ No newline at end of file
diff --git a/tools/dataProvider.cpp b/tools/dataProvider.cpp
new file mode 100644
index 0000000..deb3b78
--- /dev/null
+++ b/tools/dataProvider.cpp
@@ -0,0 +1,7 @@
+//
+// Created by joe on 5/19/15.
+//
+
+#include "../tools/dataProvider.h"
+
+bool RidList::single = false; \ No newline at end of file
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index bf47856..348d15c 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -61,6 +61,7 @@ private:
std::vector<DataEntry*> uniq;
std::vector<DataEntry*> other;
public:
+ static bool single;
void clear(){
uniq.clear();
other.clear();
@@ -98,6 +99,8 @@ public:
}
inline int getqSize()
{
+ if (single)
+ return (int)other.size();
return (int)(uniq.size()+other.size()-1);
}
inline int getuSize()
@@ -113,6 +116,8 @@ public:
a=x/n;
b=x%n;
Eigen::VectorXd vec;
+ if (single)
+ return (uniq[a]->feature-other[b]->feature).cwiseAbs();
if (b<a)
vec=uniq[a]->feature-uniq[b]->feature;
else
@@ -126,6 +131,12 @@ public:
int a,b,n=getqSize();
a=x/n;
b=x%n;
+ if (single)
+ {
+ if (std::fabs(other[b]->rank - a) < 1e-5)
+ return 1;
+ return -1;
+ }
if (b<uniq.size()-1)
return -1;
else
diff --git a/train.cpp b/train.cpp
index 0b5b4d4..4b8439a 100644
--- a/train.cpp
+++ b/train.cpp
@@ -105,6 +105,7 @@ int main(int argc, char **argv) {
("predict,P", "use model for prediction")
("cmc,C", "enable cmc auditing")
("debug,d", "show debug messages")
+ ("single,s", "one from a pair")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file")
@@ -126,6 +127,10 @@ int main(int argc, char **argv) {
el::Loggers::reconfigureLogger("default", defaultConf);
mainFunc mainf;
+ if (vm.count("single"))
+ RidList::single=true;
+ else
+ RidList::single=false;
if (vm.count("train")) {
if (vm.count("c")) {
C=vm["c"].as<double>();