summaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
Diffstat (limited to 'model')
-rw-r--r--model/ranksvm.cpp18
-rw-r--r--model/ranksvm.h10
-rw-r--r--model/ranksvmtn.h4
3 files changed, 17 insertions, 15 deletions
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp
index b15d2ef..58a097a 100644
--- a/model/ranksvm.cpp
+++ b/model/ranksvm.cpp
@@ -1,21 +1,23 @@
#include"ranksvm.h"
-#include"ranksvmtron.h"
+#include"ranksvmtn.h"
+#include"../tools/matrixIO.h"
#include<iostream>
#include<fstream>
#include<string>
using namespace Eigen;
+using namespace std;
-int RSVM::saveModel(string fname){
+int RSVM::saveModel(const string fname){
- std::ofstream fout(fname);
+ std::ofstream fout(fname.c_str());
fout<<this->getName()<<endl;
fout<<this->model;
return 0;
}
-static RSVM* RSVM::loadModel(string fname){
- std::ifstream fin(fname);
+RSVM* RSVM::loadModel(const string fname){
+ std::ifstream fin(fname.c_str());
std::string type;
int fsize;
fin>>type;
@@ -25,17 +27,17 @@ static RSVM* RSVM::loadModel(string fname){
// TODO multiplex type
if (type=="TN")
- RSVM = new RSVMTN();
+ rsvm = new RSVMTN();
rsvm->fsize=fsize;
VectorXd model;
- fin>>model;
+ Eigen::read_stream(fin, model);
rsvm->setModel(model);
return rsvm;
}
-int RSVM::setModel(Eigen::VectorXd model) {
+int RSVM::setModel(const Eigen::VectorXd &model) {
if (model.cols()!=fsize)
LOG(FATAL) << "Feature size mismatch";;
this->model=model;
diff --git a/model/ranksvm.h b/model/ranksvm.h
index 8993b87..fad790d 100644
--- a/model/ranksvm.h
+++ b/model/ranksvm.h
@@ -12,14 +12,14 @@ protected:
Eigen::VectorXd model;
int fsize;
public:
- virtual int train(DataProvider D)=0;
- virtual int predict(DataProvider D);
- int saveModel(std::string fname);
- static RSVM loadModel(std::string fname);
+ virtual int train(DataProvider &D)=0;
+ virtual int predict(DataProvider &D)=0;
+ int saveModel(const std::string fname);
+ static RSVM* loadModel(const std::string fname);
virtual std::string getName()=0;
Eigen::MatrixXd getModel(){
return model;};
- void setModel(Eigen::VectorXd model);
+ int setModel(const Eigen::VectorXd &model);
};
#endif \ No newline at end of file
diff --git a/model/ranksvmtn.h b/model/ranksvmtn.h
index 2a8f524..4a0fb16 100644
--- a/model/ranksvmtn.h
+++ b/model/ranksvmtn.h
@@ -13,8 +13,8 @@ public:
return "TN";
};
- int train(DataProvider D){return 0;};
- int predict(DataProvider D){return 0;};
+ int train(DataProvider &D){return 0;};
+ int predict(DataProvider &D){return 0;};
};
#endif \ No newline at end of file