summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-03-08 17:47:33 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-03-08 17:47:33 +0800
commitf2d01e30f459818f0589e06839d38999aecfdc06 (patch)
tree9530ac898c1d4cdecbb5194cbd76288e57f7f7b1
parent22882d7113c13cb1e00c59b54050f16ac1b7cc30 (diff)
downloadranksvm-f2d01e30f459818f0589e06839d38999aecfdc06.tar.gz
ranksvm-f2d01e30f459818f0589e06839d38999aecfdc06.tar.bz2
ranksvm-f2d01e30f459818f0589e06839d38999aecfdc06.zip
scaffolding
-rw-r--r--CMakeLists.txt2
-rw-r--r--main.cpp17
-rw-r--r--model/ranksvm.h6
-rw-r--r--model/ranksvmtn.cpp12
-rw-r--r--model/ranksvmtn.h5
-rw-r--r--tools/dataProvider.h17
-rw-r--r--tools/fileDataProvider.h15
7 files changed, 57 insertions, 17 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6ccf0c8..e1eb353 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,6 +11,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )
INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR})
-set(SOURCE_FILES main.cpp ./model/ranksvm.cpp)
+set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp)
add_executable(ranksvm ${SOURCE_FILES})
TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) \ No newline at end of file
diff --git a/main.cpp b/main.cpp
index 8d5b393..87f9ce5 100644
--- a/main.cpp
+++ b/main.cpp
@@ -17,7 +17,14 @@ int train() {
RSVM *rsvm;
rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
FileDP dp(vm["feature"].as<std::string>().c_str());
- rsvm->train(dp);
+ DataSet D;
+ Labels L;
+ while (!dp.EOFile())
+ {
+ dp.getDataSet(D);
+ dp.getLabel(L);
+ rsvm->train(D,L);
+ }
rsvm->saveModel(vm["output"].as<std::string>().c_str());
delete rsvm;
return 0;
@@ -27,7 +34,13 @@ int predict() {
RSVM *rsvm;
rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
FileDP dp(vm["feature"].as<std::string>().c_str());
- rsvm->predict(dp);
+ DataSet D;
+ MatrixXd L;
+ while (!dp.EOFile())
+ {
+ dp.getDataSet(D);
+ rsvm->predict(D,L);
+ }
delete rsvm;
return 0;
}
diff --git a/model/ranksvm.h b/model/ranksvm.h
index e7b7c4a..21fb30b 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -12,8 +12,10 @@ protected:
Eigen::VectorXd model;
int fsize;
public:
- virtual int train(DataProvider &D)=0; // Dataprovider will have to provide label
- virtual int predict(DataProvider &D)=0; // TODO Not sure how to construct this
+ virtual int train(DataSet &D, Labels &label)=0;
+ virtual int predict(DataSet &D, Eigen::MatrixXd &res)=0;
+ // TODO Not sure how to construct this
+ // Possible solution: generate a nxn matrix each row contains the sorted list of ranker result.
int saveModel(const std::string fname);
static RSVM* loadModel(const std::string fname);
virtual std::string getName()=0;
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp
new file mode 100644
index 0000000..ef8d98c
--- /dev/null
+++ b/model/ranksvmtn.cpp
@@ -0,0 +1,12 @@
+#include "ranksvmtn.h"
+
+using namespace std;
+using namespace Eigen;
+
+int RSVMTN::train(DataSet &D, Labels &label){
+ return 0;
+};
+
+int RSVMTN::predict(DataSet &D, MatrixXd &res){
+ return 0;
+}; \ No newline at end of file
diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h
index 4a0fb16..21b03bd 100644
--- a/model/ranksvmtn.h
+++ b/model/ranksvmtn.h
@@ -12,9 +12,8 @@ public:
{
return "TN";
};
-
- int train(DataProvider &D){return 0;};
- int predict(DataProvider &D){return 0;};
+ virtual int train(DataSet &D, Labels &label);
+ virtual int predict(DataSet &D, Eigen::MatrixXd &res);
};
#endif \ No newline at end of file
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index d9440ce..0e6ed9e 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -3,8 +3,21 @@
#include<Eigen/Dense>
#include "../tools/easylogging++.h"
+#include<vector>
// TODO decide how to construct training data
+// One possible way for training data:
+// Matrix composed of an array of feature vectors
+// Labels are composed of linked list, such as
+// 6,3,4,0,5,0,0
+// => 0->6 | 1->3 | 2->4->5
+// How to compensate for non exhaustive labeling?
+// Use -1 to indicate not yet labeled data
+// -1s will be excluded from training
+
+typedef Eigen::MatrixXd DataSet;
+
+typedef std::vector<double> Labels;
class DataProvider //Virtual base class for data input
{
@@ -19,8 +32,8 @@ public:
return attrSize;
}
- virtual Eigen::MatrixXd* getAttr() = 0;
- virtual Eigen::VectorXd* getPref() = 0;
+ virtual int getDataSet(DataSet &out) = 0;
+ virtual int getLabel(Labels &out) = 0;
virtual int open()=0;
virtual bool EOFile()=0;
};
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index 7937866..fd8f00d 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -8,17 +8,18 @@ class FileDP:public DataProvider
{
private:
std::string fname;
+ bool eof;
public:
- FileDP(){};
- FileDP(std::string fn):fname(fn){};
- virtual Eigen::MatrixXd* getAttr(){
- return new Eigen::MatrixXd(3,3);
+ FileDP(std::string fn=""):fname(fn),eof(false){};
+ void setFname(std::string fn){fname=fn;};
+ virtual int getDataSet(DataSet &out){
+ return 0;
}
- virtual Eigen::VectorXd* getPref(){
- return new Eigen::VectorXd(3);
+ virtual int getLabel(Labels &out){
+ return 0;
};
virtual int open(){return 0;};
- virtual bool EOFile(){return true;};
+ virtual bool EOFile(){return eof;};
};
#endif \ No newline at end of file