From f6c22b46449fa77f90e319e4b159ccb6c2a5732b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Wed, 11 Mar 2015 00:55:41 +0800 Subject: restructure, changed label type --- main.cpp | 17 +++++++++-------- model/ranksvm.cpp | 2 +- model/ranksvm.h | 4 ++-- model/ranksvmtn.cpp | 2 +- model/ranksvmtn.h | 2 +- tools/dataProvider.h | 3 ++- tools/fileDataProvider.h | 1 + 7 files changed, 17 insertions(+), 14 deletions(-) diff --git a/main.cpp b/main.cpp index c297f30..b479f59 100644 --- a/main.cpp +++ b/main.cpp @@ -5,6 +5,7 @@ #include "tools/easylogging++.h" #include "model/ranksvmtn.h" #include "tools/fileDataProvider.h" +#include "tools/matrixIO.h" INITIALIZE_EASYLOGGINGPP @@ -22,17 +23,15 @@ int train() { dp.open(); DataSet D; Labels L; + LOG(INFO)<<"Training started"; - while (!dp.EOFile()) - { - dp.getDataSet(D); - dp.getLabel(L); - rsvm->train(D,L); - } + dp.getDataSet(D); + dp.getLabel(L); + rsvm->train(D,L); LOG(INFO)<<"Training finished,saving model"; - + dp.close(); rsvm->saveModel(vm["output"].as().c_str()); delete rsvm; return 0; @@ -43,12 +42,14 @@ int predict() { rsvm = RSVM::loadModel(vm["model"].as().c_str()); FileDP dp(vm["feature"].as().c_str()); DataSet D; - MatrixXd L; + Labels L; while (!dp.EOFile()) { dp.getDataSet(D); rsvm->predict(D,L); } + + Eigen::write_stream(std::cout, L); delete rsvm; return 0; } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 060001b..628ef37 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -37,7 +37,7 @@ RSVM* RSVM::loadModel(const string fname){ return rsvm; } -int RSVM::setModel(const Eigen::VectorXd &model) { +int RSVM::setModel(const Labels &model) { if (model.rows()!=fsize) LOG(FATAL) << "Feature size mismatch: "<model=model; diff --git a/model/ranksvm.h b/model/ranksvm.h index 21fb30b..5217b56 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -13,7 +13,7 @@ protected: int fsize; public: virtual int train(DataSet &D, Labels &label)=0; - virtual int predict(DataSet &D, Eigen::MatrixXd &res)=0; + virtual int predict(DataSet &D, Labels &res)=0; // TODO Not sure how to construct this // Possible solution: generate a nxn matrix each row contains the sorted list of ranker result. int saveModel(const std::string fname); @@ -21,7 +21,7 @@ public: virtual std::string getName()=0; Eigen::MatrixXd getModel(){ return model;}; - int setModel(const Eigen::VectorXd &model); + int setModel(const Labels &model); }; #endif \ No newline at end of file diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index ef8d98c..746e967 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -7,6 +7,6 @@ int RSVMTN::train(DataSet &D, Labels &label){ return 0; }; -int RSVMTN::predict(DataSet &D, MatrixXd &res){ +int RSVMTN::predict(DataSet &D, Labels &res){ return 0; }; \ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 21b03bd..cdb9796 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -13,7 +13,7 @@ public: return "TN"; }; virtual int train(DataSet &D, Labels &label); - virtual int predict(DataSet &D, Eigen::MatrixXd &res); + virtual int predict(DataSet &D, Labels &res); }; #endif \ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index ce2bf12..d311149 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -17,7 +17,7 @@ typedef Eigen::MatrixXd DataSet; -typedef std::vector Labels; +typedef Eigen::VectorXd Labels; class DataProvider //Virtual base class for data input { @@ -39,6 +39,7 @@ public: virtual int getDataSet(DataSet &out) = 0; virtual int getLabel(Labels &out) = 0; virtual int open()=0; + virtual int close()=0; }; #endif \ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 8a499ca..a4cc252 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -18,6 +18,7 @@ public: return 0; }; virtual int open(){eof=true;return 0;}; + virtual int close(){return 0;}; }; #endif \ No newline at end of file -- cgit v1.2.3-70-g09d2