From f2d01e30f459818f0589e06839d38999aecfdc06 Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Sun, 8 Mar 2015 17:47:33 +0800 Subject: scaffolding --- CMakeLists.txt | 2 +- main.cpp | 17 +++++++++++++++-- model/ranksvm.h | 6 ++++-- model/ranksvmtn.cpp | 12 ++++++++++++ model/ranksvmtn.h | 5 ++--- tools/dataProvider.h | 17 +++++++++++++++-- tools/fileDataProvider.h | 15 ++++++++------- 7 files changed, 57 insertions(+), 17 deletions(-) create mode 100644 model/ranksvmtn.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ccf0c8..e1eb353 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED ) INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) -set(SOURCE_FILES main.cpp ./model/ranksvm.cpp) +set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp) add_executable(ranksvm ${SOURCE_FILES}) TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) \ No newline at end of file diff --git a/main.cpp b/main.cpp index 8d5b393..87f9ce5 100644 --- a/main.cpp +++ b/main.cpp @@ -17,7 +17,14 @@ int train() { RSVM *rsvm; rsvm = RSVM::loadModel(vm["model"].as().c_str()); FileDP dp(vm["feature"].as().c_str()); - rsvm->train(dp); + DataSet D; + Labels L; + while (!dp.EOFile()) + { + dp.getDataSet(D); + dp.getLabel(L); + rsvm->train(D,L); + } rsvm->saveModel(vm["output"].as().c_str()); delete rsvm; return 0; @@ -27,7 +34,13 @@ int predict() { RSVM *rsvm; rsvm = RSVM::loadModel(vm["model"].as().c_str()); FileDP dp(vm["feature"].as().c_str()); - rsvm->predict(dp); + DataSet D; + MatrixXd L; + while (!dp.EOFile()) + { + dp.getDataSet(D); + rsvm->predict(D,L); + } delete rsvm; return 0; } diff --git a/model/ranksvm.h b/model/ranksvm.h index e7b7c4a..21fb30b 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,8 +12,10 @@ protected: Eigen::VectorXd model; int fsize; public: - virtual int train(DataProvider &D)=0; // Dataprovider will have to provide label - virtual int predict(DataProvider &D)=0; // TODO Not sure how to construct this + virtual int train(DataSet &D, Labels &label)=0; + virtual int predict(DataSet &D, Eigen::MatrixXd &res)=0; + // TODO Not sure how to construct this + // Possible solution: generate a nxn matrix each row contains the sorted list of ranker result. int saveModel(const std::string fname); static RSVM* loadModel(const std::string fname); virtual std::string getName()=0; diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp new file mode 100644 index 0000000..ef8d98c --- /dev/null +++ b/model/ranksvmtn.cpp @@ -0,0 +1,12 @@ +#include "ranksvmtn.h" + +using namespace std; +using namespace Eigen; + +int RSVMTN::train(DataSet &D, Labels &label){ + return 0; +}; + +int RSVMTN::predict(DataSet &D, MatrixXd &res){ + return 0; +}; \ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 4a0fb16..21b03bd 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -12,9 +12,8 @@ public: { return "TN"; }; - - int train(DataProvider &D){return 0;}; - int predict(DataProvider &D){return 0;}; + virtual int train(DataSet &D, Labels &label); + virtual int predict(DataSet &D, Eigen::MatrixXd &res); }; #endif \ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index d9440ce..0e6ed9e 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -3,8 +3,21 @@ #include #include "../tools/easylogging++.h" +#include // TODO decide how to construct training data +// One possible way for training data: +// Matrix composed of an array of feature vectors +// Labels are composed of linked list, such as +// 6,3,4,0,5,0,0 +// => 0->6 | 1->3 | 2->4->5 +// How to compensate for non exhaustive labeling? +// Use -1 to indicate not yet labeled data +// -1s will be excluded from training + +typedef Eigen::MatrixXd DataSet; + +typedef std::vector Labels; class DataProvider //Virtual base class for data input { @@ -19,8 +32,8 @@ public: return attrSize; } - virtual Eigen::MatrixXd* getAttr() = 0; - virtual Eigen::VectorXd* getPref() = 0; + virtual int getDataSet(DataSet &out) = 0; + virtual int getLabel(Labels &out) = 0; virtual int open()=0; virtual bool EOFile()=0; }; diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 7937866..fd8f00d 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -8,17 +8,18 @@ class FileDP:public DataProvider { private: std::string fname; + bool eof; public: - FileDP(){}; - FileDP(std::string fn):fname(fn){}; - virtual Eigen::MatrixXd* getAttr(){ - return new Eigen::MatrixXd(3,3); + FileDP(std::string fn=""):fname(fn),eof(false){}; + void setFname(std::string fn){fname=fn;}; + virtual int getDataSet(DataSet &out){ + return 0; } - virtual Eigen::VectorXd* getPref(){ - return new Eigen::VectorXd(3); + virtual int getLabel(Labels &out){ + return 0; }; virtual int open(){return 0;}; - virtual bool EOFile(){return true;}; + virtual bool EOFile(){return eof;}; }; #endif \ No newline at end of file -- cgit v1.2.3-70-g09d2