summaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'main.cpp')
-rw-r--r--main.cpp44
1 files changed, 24 insertions, 20 deletions
diff --git a/main.cpp b/main.cpp
index eeb6b99..0c71b07 100644
--- a/main.cpp
+++ b/main.cpp
@@ -5,8 +5,7 @@
#include "tools/easylogging++.h"
#include "model/ranksvmtn.h"
#include "tools/fileDataProvider.h"
-#include "tools/matrixIO.h"
-#include <fstream>
+#include "model/rankaccu.h"
INITIALIZE_EASYLOGGINGPP
@@ -28,6 +27,10 @@ int train() {
dp.getDataSet(D);
LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
rsvm->train(D);
+ std::vector<double> L;
+ rsvm->predict(D,L);
+
+ rank_accu(D,L);
LOG(INFO)<<"Training finished,saving model";
@@ -54,24 +57,31 @@ int predict() {
rsvm->predict(D,L);
}
- LOG(INFO)<<"Finished,saving prediction";
- std::ofstream fout(vm["output"].as<std::string>().c_str());
+ if (vm.count("validate"))
+ {
+ rank_accu(D,L);
+ }
- for (int i=0; i<L.size();++i)
- fout<<L[i]<<std::endl;
- fout.close();
+ if (vm.count("output"))
+ {
+ LOG(INFO)<<"Finished,saving prediction";
+ std::ofstream fout(vm["output"].as<std::string>().c_str());
+ for (int i=0; i<L.size();++i)
+ fout<<L[i]<<std::endl;
+ fout.close();
+ }
+ else if (!vm.count("validate"))
+ {
+ LOG(INFO)<<"Finished";
+ for (int i=0; i<L.size();++i)
+ std::cout<<L[i]<<std::endl;
+ }
dp.close();
delete rsvm;
return 0;
}
-int validate()
-{
- LOG(FATAL)<<"Not Implemented";
- return 0;
-}
-
int main(int argc, char **argv) {
// Defining program options
po::options_description desc("Allowed options");
@@ -95,15 +105,9 @@ int main(int argc, char **argv) {
}
if (vm.count("train")) {
- LOG(INFO) << "Program option: training";
train();
}
- else if (vm.count("validate")) {
- LOG(INFO) << "Program option: validate";
- validate();
- }
- else if (vm.count("predict")) {
- LOG(INFO) << "Program option: predict";
+ else if (vm.count("validate")||vm.count("predict")) {
predict();
}
return 0;