diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-03-11 00:55:41 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-03-11 00:55:41 +0800 |
commit | f6c22b46449fa77f90e319e4b159ccb6c2a5732b (patch) | |
tree | bc38dde03ffa5cb3c2fea9f5fefff0b990de405b | |
parent | 3d204f5fe4614624ca342090feecbfe4df188d9d (diff) | |
download | ranksvm-f6c22b46449fa77f90e319e4b159ccb6c2a5732b.tar.gz ranksvm-f6c22b46449fa77f90e319e4b159ccb6c2a5732b.tar.bz2 ranksvm-f6c22b46449fa77f90e319e4b159ccb6c2a5732b.zip |
restructure, changed label type
-rw-r--r-- | main.cpp | 17 | ||||
-rw-r--r-- | model/ranksvm.cpp | 2 | ||||
-rw-r--r-- | model/ranksvm.h | 4 | ||||
-rw-r--r-- | model/ranksvmtn.cpp | 2 | ||||
-rw-r--r-- | model/ranksvmtn.h | 2 | ||||
-rw-r--r-- | tools/dataProvider.h | 3 | ||||
-rw-r--r-- | tools/fileDataProvider.h | 1 |
7 files changed, 17 insertions, 14 deletions
@@ -5,6 +5,7 @@ #include "tools/easylogging++.h" #include "model/ranksvmtn.h" #include "tools/fileDataProvider.h" +#include "tools/matrixIO.h" INITIALIZE_EASYLOGGINGPP @@ -22,17 +23,15 @@ int train() { dp.open(); DataSet D; Labels L; + LOG(INFO)<<"Training started"; - while (!dp.EOFile()) - { - dp.getDataSet(D); - dp.getLabel(L); - rsvm->train(D,L); - } + dp.getDataSet(D); + dp.getLabel(L); + rsvm->train(D,L); LOG(INFO)<<"Training finished,saving model"; - + dp.close(); rsvm->saveModel(vm["output"].as<std::string>().c_str()); delete rsvm; return 0; @@ -43,12 +42,14 @@ int predict() { rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); FileDP dp(vm["feature"].as<std::string>().c_str()); DataSet D; - MatrixXd L; + Labels L; while (!dp.EOFile()) { dp.getDataSet(D); rsvm->predict(D,L); } + + Eigen::write_stream(std::cout, L); delete rsvm; return 0; } diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 060001b..628ef37 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -37,7 +37,7 @@ RSVM* RSVM::loadModel(const string fname){ return rsvm; } -int RSVM::setModel(const Eigen::VectorXd &model) { +int RSVM::setModel(const Labels &model) { if (model.rows()!=fsize) LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.cols(); this->model=model; diff --git a/model/ranksvm.h b/model/ranksvm.h index 21fb30b..5217b56 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -13,7 +13,7 @@ protected: int fsize; public: virtual int train(DataSet &D, Labels &label)=0; - virtual int predict(DataSet &D, Eigen::MatrixXd &res)=0; + virtual int predict(DataSet &D, Labels &res)=0; // TODO Not sure how to construct this // Possible solution: generate a nxn matrix each row contains the sorted list of ranker result. int saveModel(const std::string fname); @@ -21,7 +21,7 @@ public: virtual std::string getName()=0; Eigen::MatrixXd getModel(){ return model;}; - int setModel(const Eigen::VectorXd &model); + int setModel(const Labels &model); }; #endif
\ No newline at end of file diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index ef8d98c..746e967 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -7,6 +7,6 @@ int RSVMTN::train(DataSet &D, Labels &label){ return 0; }; -int RSVMTN::predict(DataSet &D, MatrixXd &res){ +int RSVMTN::predict(DataSet &D, Labels &res){ return 0; };
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 21b03bd..cdb9796 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -13,7 +13,7 @@ public: return "TN"; }; virtual int train(DataSet &D, Labels &label); - virtual int predict(DataSet &D, Eigen::MatrixXd &res); + virtual int predict(DataSet &D, Labels &res); }; #endif
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index ce2bf12..d311149 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -17,7 +17,7 @@ typedef Eigen::MatrixXd DataSet; -typedef std::vector<double> Labels; +typedef Eigen::VectorXd Labels; class DataProvider //Virtual base class for data input { @@ -39,6 +39,7 @@ public: virtual int getDataSet(DataSet &out) = 0; virtual int getLabel(Labels &out) = 0; virtual int open()=0; + virtual int close()=0; }; #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 8a499ca..a4cc252 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -18,6 +18,7 @@ public: return 0; }; virtual int open(){eof=true;return 0;}; + virtual int close(){return 0;}; }; #endif
\ No newline at end of file |