diff options
Diffstat (limited to 'model')
-rw-r--r-- | model/ranksvm.cpp | 18 | ||||
-rw-r--r-- | model/ranksvm.h | 10 | ||||
-rw-r--r-- | model/ranksvmtn.h | 4 |
3 files changed, 17 insertions, 15 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index b15d2ef..58a097a 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -1,21 +1,23 @@ #include"ranksvm.h" -#include"ranksvmtron.h" +#include"ranksvmtn.h" +#include"../tools/matrixIO.h" #include<iostream> #include<fstream> #include<string> using namespace Eigen; +using namespace std; -int RSVM::saveModel(string fname){ +int RSVM::saveModel(const string fname){ - std::ofstream fout(fname); + std::ofstream fout(fname.c_str()); fout<<this->getName()<<endl; fout<<this->model; return 0; } -static RSVM* RSVM::loadModel(string fname){ - std::ifstream fin(fname); +RSVM* RSVM::loadModel(const string fname){ + std::ifstream fin(fname.c_str()); std::string type; int fsize; fin>>type; @@ -25,17 +27,17 @@ static RSVM* RSVM::loadModel(string fname){ // TODO multiplex type if (type=="TN") - RSVM = new RSVMTN(); + rsvm = new RSVMTN(); rsvm->fsize=fsize; VectorXd model; - fin>>model; + Eigen::read_stream(fin, model); rsvm->setModel(model); return rsvm; } -int RSVM::setModel(Eigen::VectorXd model) { +int RSVM::setModel(const Eigen::VectorXd &model) { if (model.cols()!=fsize) LOG(FATAL) << "Feature size mismatch";; this->model=model; diff --git a/model/ranksvm.h b/model/ranksvm.h index 8993b87..fad790d 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -12,14 +12,14 @@ protected: Eigen::VectorXd model; int fsize; public: - virtual int train(DataProvider D)=0; - virtual int predict(DataProvider D); - int saveModel(std::string fname); - static RSVM loadModel(std::string fname); + virtual int train(DataProvider &D)=0; + virtual int predict(DataProvider &D)=0; + int saveModel(const std::string fname); + static RSVM* loadModel(const std::string fname); virtual std::string getName()=0; Eigen::MatrixXd getModel(){ return model;}; - void setModel(Eigen::VectorXd model); + int setModel(const Eigen::VectorXd &model); }; #endif
\ No newline at end of file diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h index 2a8f524..4a0fb16 100644 --- a/model/ranksvmtn.h +++ b/model/ranksvmtn.h @@ -13,8 +13,8 @@ public: return "TN"; }; - int train(DataProvider D){return 0;}; - int predict(DataProvider D){return 0;}; + int train(DataProvider &D){return 0;}; + int predict(DataProvider &D){return 0;}; }; #endif
\ No newline at end of file |