diff options
-rw-r--r-- | model/ranksvmtn.cpp | 6 | ||||
-rw-r--r-- | tools/dataProvider.h | 23 | ||||
-rw-r--r-- | tools/fileDataProvider.h | 7 |
3 files changed, 26 insertions, 10 deletions
diff --git a/model/ranksvmtn.cpp b/model/ranksvmtn.cpp index 559723c..3c9808b 100644 --- a/model/ranksvmtn.cpp +++ b/model/ranksvmtn.cpp @@ -126,8 +126,8 @@ int RSVMTN::train(DataSet &D, Labels &label){ }; int RSVMTN::predict(DataSet &D, Labels &res){ - res = model.weight * D; - for (int i=0;i<res.cols();++i) - res[i] = (res[i] + model.beta); + //TODO define A + MatrixXd A; + res = A*(D * model.weight); return 0; };
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index d311149..bff1f44 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -4,6 +4,7 @@ #include<Eigen/Dense> #include "../tools/easylogging++.h" #include<vector> +#include<list> // TODO decide how to construct training data // One possible way for training data: @@ -19,6 +20,23 @@ typedef Eigen::MatrixXd DataSet; typedef Eigen::VectorXd Labels; +typedef struct DataEntry{ + int qid; + double rank; + Eigen::VectorXd feature; +} DataEntry; + +class DataList{ +private: + int n; + std::list<DataEntry> data; +public: + int getSize(){return data.size();} + void addEntry(DataEntry d){data.push_front(d);} + void setfSize(int fsize){n=fsize;} + int getfSize(){return n;} +}; + class DataProvider //Virtual base class for data input { protected: @@ -34,10 +52,9 @@ public: return attrSize; } - bool EOFile(){return eof;}; + bool EOFile(){return eof;} - virtual int getDataSet(DataSet &out) = 0; - virtual int getLabel(Labels &out) = 0; + virtual int getDataSet(DataList &out) = 0; virtual int open()=0; virtual int close()=0; }; diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index a4cc252..9ce78e6 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -11,12 +11,11 @@ private: public: FileDP(std::string fn=""):fname(fn){}; void setFname(std::string fn){fname=fn;}; - virtual int getDataSet(DataSet &out){ + virtual int getDataSet(DataList &out){ return 0; } - virtual int getLabel(Labels &out){ - return 0; - }; + int getDataSet(DataSet &D) {return 0;} + int getLabel(Labels &l) {return 0;} virtual int open(){eof=true;return 0;}; virtual int close(){return 0;}; }; |