summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-04-26 22:53:13 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-04-26 22:53:13 +0800
commita8a7bf5f9b9a1eb0d41f839afd06cc532356a902 (patch)
treea70dbb38d76d30a84a8298096a40830ba7b528a7
parentbe756a55086b5a8f62b979b456475c86ec2cfb61 (diff)
downloadranksvm-a8a7bf5f9b9a1eb0d41f839afd06cc532356a902.tar.gz
ranksvm-a8a7bf5f9b9a1eb0d41f839afd06cc532356a902.tar.bz2
ranksvm-a8a7bf5f9b9a1eb0d41f839afd06cc532356a902.zip
getAllData
-rw-r--r--CMakeLists.txt2
-rw-r--r--main.cpp40
-rw-r--r--tools/dataProvider.h10
-rw-r--r--tools/reidFDataProvider.h8
4 files changed, 39 insertions, 21 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3d8a4e3..18ef86f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,5 +12,5 @@ FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )
INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR})
set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp ./model/rankaccu.cpp)
-add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/rankaccu.cpp)
+add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/rankaccu.cpp tools/reidFDataProvider.h)
TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) \ No newline at end of file
diff --git a/main.cpp b/main.cpp
index d5c6f69..cf1dd14 100644
--- a/main.cpp
+++ b/main.cpp
@@ -24,7 +24,7 @@ int train(DataProvider &dp) {
DataList D;
LOG(INFO)<<"Training started";
- dp.getDataSet(D);
+ dp.getAllData(D);
LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
rsvm->train(D);
std::vector<double> L;
@@ -49,30 +49,32 @@ int predict(DataProvider &dp) {
std::vector<double> L;
LOG(INFO)<<"Prediction started";
- dp.getDataSet(D);
- LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
- rsvm->predict(D,L);
+ std::ofstream fout;
+ if (vm.count("output"))
+ fout.open(vm["output"].as<std::string>().c_str());
- if (vm.count("validate"))
+ while (!dp.EOFile())
{
- rank_accu(D,L);
+ dp.getDataSet(D);
+ LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
+ rsvm->predict(D,L);
+
+ if (vm.count("validate"))
+ {
+ rank_accu(D,L);
+ }
+
+ if (vm.count("output"))
+ for (int i=0; i<L.size();++i)
+ fout<<L[i]<<std::endl;
+ else if (!vm.count("validate"))
+ for (int i=0; i<L.size();++i)
+ std::cout<<L[i]<<std::endl;
}
+ LOG(INFO)<<"Finished";
if (vm.count("output"))
- {
- LOG(INFO)<<"Finished,saving prediction";
- std::ofstream fout(vm["output"].as<std::string>().c_str());
-
- for (int i=0; i<L.size();++i)
- fout<<L[i]<<std::endl;
fout.close();
- }
- else if (!vm.count("validate"))
- {
- LOG(INFO)<<"Finished";
- for (int i=0; i<L.size();++i)
- std::cout<<L[i]<<std::endl;
- }
dp.close();
delete rsvm;
return 0;
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index 2c3169a..64bfa2d 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -51,7 +51,15 @@ public:
DataProvider():eof(false){};
bool EOFile(){return eof;}
-
+ int getAllData(DataList &out){\
+ out.clear();
+ DataList buf;
+ while (!EOFile())
+ {
+ getDataSet(buf);
+ out.getData().insert(out.getData().end(),buf.getData().begin(),buf.getData().end());
+ }
+ }
virtual int getDataSet(DataList &out) = 0;
virtual int open()=0;
virtual int close()=0;
diff --git a/tools/reidFDataProvider.h b/tools/reidFDataProvider.h
new file mode 100644
index 0000000..9fa833a
--- /dev/null
+++ b/tools/reidFDataProvider.h
@@ -0,0 +1,8 @@
+//
+// Created by joe on 4/26/15.
+//
+
+#ifndef RANKSVM_REIDFDATAPROVIDER_H
+#define RANKSVM_REIDFDATAPROVIDER_H
+
+#endif //RANKSVM_REIDFDATAPROVIDER_H