summaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'main.cpp')
-rw-r--r--main.cpp20
1 files changed, 17 insertions, 3 deletions
diff --git a/main.cpp b/main.cpp
index 1cb18b9..e89cfe1 100644
--- a/main.cpp
+++ b/main.cpp
@@ -6,6 +6,7 @@
#include "model/ranksvmtn.h"
#include "tools/fileDataProvider.h"
#include "tools/matrixIO.h"
+#include <fstream>
INITIALIZE_EASYLOGGINGPP
@@ -25,6 +26,7 @@ int train() {
LOG(INFO)<<"Training started";
dp.getDataSet(D);
+ LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
rsvm->train(D);
LOG(INFO)<<"Training finished,saving model";
@@ -39,15 +41,27 @@ int predict() {
RSVM *rsvm;
rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
FileDP dp(vm["feature"].as<std::string>().c_str());
+
+ dp.open();
DataList D;
- std::list<double> L;
+ std::vector<double> L;
+ LOG(INFO)<<"Prediction started";
+
while (!dp.EOFile())
{
dp.getDataSet(D);
+ LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
rsvm->predict(D,L);
}
- // TODO output Eigen::write_stream(std::cout, L);
+ LOG(INFO)<<"Training finished,saving prediction";
+ std::ofstream fout(vm["output"].as<std::string>().c_str());
+
+ for (int i=0; i<L.size();++i)
+ fout<<L[i]<<std::endl;
+ fout.close();
+
+ dp.close();
delete rsvm;
return 0;
}
@@ -67,7 +81,7 @@ int main(int argc, char **argv) {
("validate,V", "validate model")
("predict,P", "use model for prediction")
("model,m", po::value<std::string>(), "set input model file")
- ("output,o", po::value<std::string>(), "set output model file")
+ ("output,o", po::value<std::string>(), "set output model/prediction file")
("feature,i", po::value<std::string>(), "set input feature file");
// Parsing program options