diff options
-rw-r--r-- | model/ranksvm.cpp | 17 | ||||
-rw-r--r-- | model/ranksvm.h | 11 |
2 files changed, 18 insertions, 10 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 628ef37..dc2ad9f 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -13,7 +13,8 @@ int RSVM::saveModel(const string fname){ std::ofstream fout(fname.c_str()); fout<<this->getName()<<endl; fout<<this->fsize<<endl; - Eigen::write_stream(fout, this->model); + Eigen::write_stream(fout, this->model.weight); + fout<<this->model.beta<<endl; return 0; } @@ -30,16 +31,18 @@ RSVM* RSVM::loadModel(const string fname){ rsvm = new RSVMTN(); rsvm->fsize=fsize; - VectorXd model; - Eigen::read_stream(fin, model); + SVMModel model; + Eigen::read_stream(fin, model.weight); + fin>>model.beta; rsvm->setModel(model); return rsvm; } -int RSVM::setModel(const Labels &model) { - if (model.rows()!=fsize) - LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.cols(); - this->model=model; +int RSVM::setModel(const SVMModel &model) { + if (model.weight.cols()!=fsize) + LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.weight.cols(); + this->model.weight=model.weight; + this->model.beta=model.beta; return 0; }
\ No newline at end of file diff --git a/model/ranksvm.h b/model/ranksvm.h index 5217b56..edb80f6 100644 --- a/model/ranksvm.h +++ b/model/ranksvm.h @@ -6,10 +6,15 @@ #include"../tools/dataProvider.h" #include "../tools/easylogging++.h" +typedef struct SVMModel{ + Eigen::VectorXd weight; + double beta; +} SVMModel; + class RSVM //Virtual base class for all RSVM operations { protected: - Eigen::VectorXd model; + SVMModel model; int fsize; public: virtual int train(DataSet &D, Labels &label)=0; @@ -19,9 +24,9 @@ public: int saveModel(const std::string fname); static RSVM* loadModel(const std::string fname); virtual std::string getName()=0; - Eigen::MatrixXd getModel(){ + SVMModel getModel(){ return model;}; - int setModel(const Labels &model); + int setModel(const SVMModel &model); }; #endif
\ No newline at end of file |