summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- CMakeLists.txt | 9
-rw-r--r-- split.cpp | 76
-rw-r--r-- tools/fileDataProvider.cpp | 173
-rw-r--r-- tools/fileDataProvider.h | 91
-rw-r--r-- train.cpp (renamed from main.cpp) | 12
5 files changed, 273 insertions, 88 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6920572..180456c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,6 +11,9 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )
INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR})
-set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp ./model/rankaccu.cpp)
-add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/easylogging++.h tools/matrixIO.h tools/fileDataProvider.h)
-TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) \ No newline at end of file
+set(SOURCE_FILES model/ranksvm.cpp model/ranksvmtn.cpp model/rankaccu.cpp tools/fileDataProvider.cpp)
+add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h)
+add_executable(split split.cpp ${SOURCE_FILES})
+add_dependencies(ranksvm split)
+TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
+TARGET_LINK_LIBRARIES( split ${Boost_LIBRARIES}) \ No newline at end of file
diff --git a/split.cpp b/split.cpp
new file mode 100644
index 0000000..be80545
--- /dev/null
+++ b/split.cpp
@@ -0,0 +1,76 @@
+//
+// Created by joe on 5/13/15.
+//
+
+#include <iostream>
+#include <boost/program_options.hpp>
+#include "tools/dataProvider.h"
+#include "tools/fileDataProvider.h"
+#include <vector>
+#include <fstream>
+
+INITIALIZE_EASYLOGGINGPP
+
+using namespace std;
+namespace po = boost::program_options;
+
+po::variables_map vm;
+
+/**
+ * Write a set of entries to a text file in the Rid layout.
+ *
+ * Layout produced: the feature count on the first line, then one line per
+ * entry ("<qid> <f0> <f1> ..."), and finally a lone "0" with no trailing
+ * newline.
+ *
+ * @param a     entries to write; the pointers are only read, never freed here
+ * @param fsize number of features emitted for every entry
+ * @param fname path of the output file
+ * @return 0 on success, 1 if the file could not be opened or a write failed
+ */
+int outputRid(const vector<DataEntry*> &a, int fsize, const string &fname)
+{
+    ofstream fout(fname.c_str());
+    if (!fout)
+        return 1;
+
+    fout << fsize << '\n';
+    for (DataEntry *entry : a)
+    {
+        fout << entry->qid;
+        for (int j = 0; j < fsize; ++j)
+            fout << " " << entry->feature(j);
+        fout << '\n';
+    }
+    fout << 0;
+
+    // close() flushes; any buffered write error surfaces in the stream state.
+    fout.close();
+    return fout.fail() ? 1 : 0;
+}
+
+/**
+ * Entry point of the `split` tool.
+ *
+ * Modes:
+ *  - `--help`  : print the option summary.
+ *  - `--query` : print the value of RidFileDP::getpSize() for the input file.
+ *  - otherwise : call RidFileDP::take() with `--count` and write the two
+ *                resulting entry lists to the `--take` and `--left` files.
+ *
+ * @return 0 on success, 1 when the command line is invalid or incomplete
+ */
+int main(int argc, char **argv)
+{
+    el::Configurations defaultConf;
+    defaultConf.setToDefault();
+    // Values are always std::string
+    defaultConf.set(el::Level::Global, el::ConfigurationType::Enabled, "false");
+    // default logger uses default configurations
+    el::Loggers::reconfigureLogger("default", defaultConf);
+
+    po::options_description desc("Allowed options");
+    desc.add_options()
+        ("help,h", "produce help message")
+        ("query,Q", "Query person count")
+        ("count,c", po::value<int>(), "take number")
+        ("take,a", po::value<string>(), "set output rid file 1(taken)")
+        ("left,b", po::value<string>(), "set output rid file 2(left)")
+        ("input,i", po::value<string>(), "set input Rid file");
+
+    // Unknown options or badly typed values throw; report them instead of aborting.
+    try {
+        po::store(po::parse_command_line(argc, argv, desc), vm);
+        po::notify(vm);
+    } catch (const po::error &err) {
+        cerr << err.what() << endl << desc;
+        return 1;
+    }
+
+    // Print help if necessary
+    if (vm.count("help")) {
+        cout << desc;
+        return 0;
+    }
+
+    // Both remaining modes read the input file, so it is mandatory from here on.
+    if (!vm.count("input")) {
+        cerr << "missing required option --input" << endl << desc;
+        return 1;
+    }
+    RidFileDP dp(vm["input"].as<string>());
+
+    if (vm.count("query")) {
+        dp.open();
+        cout << dp.getpSize() << endl;
+        dp.close();
+        return 0;
+    }
+
+    // Calling as<T>() on an absent option throws, so validate before use.
+    if (!vm.count("count") || !vm.count("take") || !vm.count("left")) {
+        cerr << "splitting requires --count, --take and --left" << endl << desc;
+        return 1;
+    }
+
+    vector<DataEntry*> a;
+    vector<DataEntry*> b;
+    dp.open();
+    dp.take(vm["count"].as<int>(), a, b);
+    // Write both files before dp.close(): close() clears the provider's data,
+    // which the pointers in a and b refer to.
+    outputRid(a, dp.getfSize(), vm["take"].as<string>());
+    outputRid(b, dp.getfSize(), vm["left"].as<string>());
+    dp.close();
+    return 0;
+} \ No newline at end of file
diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp
new file mode 100644
index 0000000..e9b7f3d
--- /dev/null
+++ b/tools/fileDataProvider.cpp
@@ -0,0 +1,173 @@
+//
+// Created by joe on 5/13/15.
+//
+
+#include "fileDataProvider.h"
+#include <ctime>
+#include <random>
+#include <set>
+#include <string>
+
+using namespace std;
+
+mt19937 gen;
+
+/**
+ * Read the whole input stream into `out` in one call.
+ *
+ * Expected layout: the feature count, then for each entry its rank, its qid
+ * and `fsize` feature values. Reading stops at a rank of 0, at end of input,
+ * or at the first token that cannot be parsed. An entry that is cut short is
+ * discarded instead of being added half-filled.
+ *
+ * The provider is marked exhausted (`eof`) after this call in every case.
+ *
+ * @param out list that is cleared and refilled; entries are heap-allocated
+ *            here and handed over through addEntry() (presumably DataList
+ *            releases them -- confirm against DataList)
+ * @return 0 on success, -1 if the feature-count header could not be read
+ */
+int FileDP::getDataSet(DataList &out){
+    out.clear();
+
+    // Initialised so a stream that is already in a fail state (e.g. the file
+    // was never opened) cannot leave an indeterminate value behind.
+    int fsize = 0;
+    if (!(fin >> fsize) || fsize < 0) {
+        LOG(ERROR) << "Cannot read feature size";
+        out.setfSize(0);
+        eof = true;
+        return -1;
+    }
+    LOG(INFO)<<"Feature size:"<<fsize;
+    out.setfSize(fsize);
+
+    while (true) {
+        DataEntry *e = new DataEntry;
+        // Test the extraction itself: eof() alone stays false when the stream
+        // fails for another reason, which would make the loop spin forever.
+        if (!(fin >> e->rank) || e->rank == 0) {
+            delete e;
+            break;
+        }
+        fin >> e->qid;
+        e->feature.resize(fsize);
+        for (int i = 0; i < fsize; ++i) {
+            fin >> e->feature(i);
+        }
+        if (!fin) {
+            // Truncated or malformed entry: drop it rather than keep garbage.
+            delete e;
+            break;
+        }
+        out.addEntry(e);
+    }
+    eof = true;
+    return 0;
+}
+
+/**
+ * Load every entry of the Rid file into the internal list `d`.
+ *
+ * Expected layout: the feature count, then for each entry its qid followed by
+ * `fsize` feature values. Reading stops at a qid of "0", at end of input, or
+ * at the first token that cannot be parsed. Every loaded entry gets rank -1.
+ *
+ * Afterwards the iteration state (`pos`, `qid`) is reset and `read` is set,
+ * even when nothing could be loaded, so the file is not parsed again.
+ */
+void RidFileDP::readEntries() {
+    d.clear();
+
+    // Initialised so a stream already in a fail state (e.g. the file was
+    // never opened) cannot leave an indeterminate feature count behind.
+    int fsize = 0;
+    bool headerOk = true;
+    if (!(fin >> fsize) || fsize < 0) {
+        LOG(ERROR) << "Cannot read feature size";
+        headerOk = false;
+        fsize = 0;
+    }
+    LOG(INFO) << "Feature size:" << fsize;
+    d.setfSize(fsize);
+
+    while (headerOk) {
+        DataEntry *e = new DataEntry;
+        // Test the extraction itself: a file without the "0" terminator would
+        // otherwise produce one bogus entry with an empty qid, and a failed
+        // stream would never report eof().
+        if (!(fin >> e->qid) || e->qid == "0") {
+            delete e;
+            break;
+        }
+        e->feature.resize(fsize);
+        e->rank = -1;
+        for (int i = 0; i < fsize; ++i) {
+            fin >> e->feature(i);
+        }
+        if (!fin) {
+            // Truncated or malformed entry: drop it rather than keep garbage.
+            delete e;
+            break;
+        }
+        d.addEntry(e);
+    }
+
+    pos = 0;
+    qid = 1;
+    read = true;
+}
+
+/**
+ * Build the data set for the current query entry (`dat[pos]`).
+ *
+ * For every other loaded entry a new DataEntry is emitted whose features are
+ * the element-wise absolute difference to the query entry. Its rank is 1 when
+ * both entries share the same qid and -1 otherwise. All emitted entries carry
+ * the query entry's qid.
+ *
+ * Afterwards the query entry gets a fresh numeric qid, and `pos` advances to
+ * the next entry whose rank is still -1. Entries that matched the query had
+ * their rank overwritten, so they are skipped. `eof` is set once no entry is
+ * left.
+ *
+ * NOTE(review): the rewritten qids ("1", "2", ...) are compared against
+ * original qids on later calls; confirm input qids cannot collide with them.
+ *
+ * @param out list that is cleared and refilled with the emitted entries
+ * @return always 0
+ */
+int RidFileDP::getDataSet(DataList &out){
+    if (!read)
+        readEntries();
+    out.clear();
+    const int fsize = d.getfSize();
+    out.setfSize(fsize);
+    std::vector<DataEntry*> & dat = d.getData();
+
+    // Nothing (left) to serve: an empty file or a call after exhaustion would
+    // otherwise index past the end of dat.
+    if (pos < 0 || static_cast<size_t>(pos) >= dat.size()) {
+        eof = true;
+        return 0;
+    }
+
+    DataEntry *query = dat[pos];
+    for (int i=0;i<d.getSize();++i)
+    {
+        if (i == pos)
+            continue;
+        DataEntry *e = new DataEntry;
+        if (dat[i]->qid == query->qid)
+        {
+            e->rank=1;
+            // Mark the match so it is not picked as a query entry later.
+            dat[i]->rank=qid;
+        }
+        else
+        {
+            e->rank=-1;
+        }
+        e->feature.resize(fsize);
+        e->qid=query->qid;
+        for (int j = 0; j < fsize; ++j) {
+            e->feature(j) = fabs(dat[i]->feature(j) - query->feature(j));
+        }
+        out.addEntry(e);
+    }
+
+    query->qid=std::to_string(qid);
+    ++qid;
+    query->rank=qid;
+    while (static_cast<size_t>(pos) < dat.size() && dat[pos]->rank!=-1)
+        ++pos;
+    if (pos==d.getSize())
+        eof = true;
+    return 0;
+}
+
+/**
+ * Count the distinct qid values among the loaded entries.
+ *
+ * Loads the file first if that has not happened yet.
+ *
+ * @return number of different qids currently stored in the entry list
+ */
+int RidFileDP::getpSize() {
+    if (!read)
+        readEntries();
+    const std::vector<DataEntry*> &dat = d.getData();
+
+    // A set keeps one copy of each qid, replacing the former quadratic scan
+    // over a vector of already-seen values.
+    std::set<std::string> seen;
+    for (size_t i = 0; i < dat.size(); ++i)
+        seen.insert(dat[i]->qid);
+    return static_cast<int>(seen.size());
+}
+
+/**
+ * Shuffle the pointers in `dat` in place using the file-level generator `gen`.
+ *
+ * Works from the back: the slot at `last` receives an element chosen from the
+ * positions 0..last, then the unshuffled range shrinks by one.
+ *
+ * @param dat vector whose elements are permuted; its size is unchanged
+ */
+void scrambler(vector<DataEntry*> &dat)
+{
+    const int total = (int)dat.size();
+    for (int last = total - 1; last >= 0; --last)
+    {
+        const int pick = (int)(gen() % (last + 1));
+        DataEntry *chosen = dat[pick];
+        dat[pick] = dat[last];
+        dat[last] = chosen;
+    }
+}
+
+/**
+ * Randomly split the loaded entries into two groups by qid.
+ *
+ * Up to `n` distinct qids are chosen at random; every entry carrying one of
+ * them goes to `a`, all remaining entries go to `b`. Both outputs are shuffled
+ * before returning. If fewer than `n` distinct qids exist, all entries end up
+ * in `a` and `b` is empty.
+ *
+ * The generator is reseeded from the current time on every call.
+ *
+ * @param n number of distinct qids to move into `a` (values <= 0 select none)
+ * @param a output: entries of the chosen qids (cleared first)
+ * @param b output: all other entries (cleared first)
+ *
+ * The pointers stored in `a` and `b` are the ones held by the internal list,
+ * not copies; presumably they become invalid once that list is cleared --
+ * confirm against DataList::clear().
+ */
+void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b)
+{
+    gen.seed(time(NULL));
+    if (!read)
+        readEntries();
+    a.clear();
+    b.clear();
+
+    std::vector<DataEntry*> &dat = d.getData();
+    // Copy first, then shuffle: shuffling the still-empty vector did nothing,
+    // so the selection always followed file order.
+    vector<DataEntry*> tmp(dat.begin(), dat.end());
+    scrambler(tmp);
+
+    size_t pos = 0;
+    for (int i = 0; i < n; ++i)
+    {
+        // Skip entries already moved to a; stop when every qid has been taken.
+        while (pos < tmp.size() && tmp[pos] == nullptr)
+            ++pos;
+        if (pos == tmp.size())
+            break;
+
+        const string person = tmp[pos]->qid;
+        for (size_t j = pos; j < tmp.size(); ++j)
+            if (tmp[j] != nullptr && tmp[j]->qid == person)
+            {
+                a.push_back(tmp[j]);
+                tmp[j] = nullptr;
+            }
+    }
+
+    for (size_t i = 0; i < tmp.size(); ++i)
+        if (tmp[i] != nullptr)
+            b.push_back(tmp[i]);
+
+    scrambler(a);
+    scrambler(b);
+} \ No newline at end of file
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index f54a38e..7bea92d 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -16,31 +16,7 @@ private:
std::ifstream fin;
public:
FileDP(std::string fn=""):fname(fn){};
- virtual int getDataSet(DataList &out){
- DataEntry* e;
- out.clear();
- int fsize;
- fin>>fsize;
- LOG(INFO)<<"Feature size:"<<fsize;
- out.setfSize(fsize);
- while (!fin.eof()) {
- e = new DataEntry;
- fin>>e->rank;
- if (e->rank == 0)
- {
- delete e;
- break;
- }
- fin>>e->qid;
- e->feature.resize(fsize);
- for (int i=0;i<fsize;++i) {
- fin>>e->feature(i);
- }
- out.addEntry(e);
- }
- eof=true;
- return 0;
- }
+ virtual int getDataSet(DataList &out);
virtual int open(){fin.open(fname); eof=false;return 0;};
virtual int close(){fin.close();return 0;};
};
@@ -58,68 +34,13 @@ private:
int qid;
public:
RidFileDP(std::string fn=""):fname(fn){read=false;};
- virtual int getDataSet(DataList &out){
- DataEntry *e;
- int fsize;
- if (!read) {
- d.clear();
- fin >> fsize;
- LOG(INFO) << "Feature size:" << fsize;
- d.setfSize(fsize);
- while (!fin.eof()) {
- e = new DataEntry;
- fin >> e->qid;
- if (e->qid == "0") {
- delete e;
- break;
- }
- e->feature.resize(fsize);
- e->rank=-1;
- for (int i = 0; i < fsize; ++i) {
- fin >> e->feature(i);
- }
- d.addEntry(e);
- }
- pos = 0;
- qid = 1;
- read = true;
- }
- out.clear();
- fsize = d.getfSize();
- out.setfSize(fsize);
- std::vector<DataEntry*> & dat = d.getData();
- for (int i=0;i<d.getSize();++i)
- if (i!=pos)
- {
- if (dat[i]->qid == dat[pos]->qid)
- {
- e = new DataEntry;
- e->rank=1;
- dat[i]->rank=qid;
- }
- else
- {
- e = new DataEntry;
- e->rank=-1;
- }
- e->feature.resize(d.getfSize());
- e->qid=dat[pos]->qid;
- for (int j = 0; j < fsize; ++j) {
- e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j));
- }
- out.addEntry(e);
- }
- dat[pos]->qid=std::to_string(qid);
- ++qid;
- dat[pos]->rank=qid;
- while (pos<dat.size() && dat[pos]->rank!=-1)
- ++pos;
- if (pos==d.getSize())
- eof = true;
- return 0;
- }
+ void readEntries();
+ int getfSize() { if(!read) readEntries(); return d.getfSize();};
+ int getpSize();
+ virtual int getDataSet(DataList &out);
virtual int open(){fin.open(fname); eof=false;return 0;};
virtual int close(){fin.close(); d.clear();return 0;};
+ void take(int n,std::vector<DataEntry*> &a,std::vector<DataEntry*> &b);
};
#endif \ No newline at end of file
diff --git a/main.cpp b/train.cpp
index e8666f8..a0c62a9 100644
--- a/main.cpp
+++ b/train.cpp
@@ -95,6 +95,11 @@ int predict(DataProvider &dp) {
}
int main(int argc, char **argv) {
+ el::Configurations defaultConf;
+ defaultConf.setToDefault();
+ // Values are always std::string
+ defaultConf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg");
+
// Defining program options
po::options_description desc("Allowed options");
desc.add_options()
@@ -103,6 +108,7 @@ int main(int argc, char **argv) {
("validate,V", "validate model")
("predict,P", "use model for prediction")
("cmc,C", "enable cmc auditing")
+ ("debug,d", "show debug messages")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file");
@@ -116,6 +122,12 @@ int main(int argc, char **argv) {
cout << desc;
return 0;
}
+
+ if (!vm.count("debug"))
+ defaultConf.setGlobally(el::ConfigurationType::Enabled, "false");
+ // default logger uses default configurations
+ el::Loggers::reconfigureLogger("default", defaultConf);
+
mainFunc mainf;
if (vm.count("train")) {
mainf = &train;