summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--main.cpp20
-rw-r--r--model/ranksvm.cpp7
-rw-r--r--tools/dataProvider.h5
-rw-r--r--tools/fileDataProvider.h6
4 files changed, 27 insertions, 11 deletions
diff --git a/main.cpp b/main.cpp
index 87f9ce5..c297f30 100644
--- a/main.cpp
+++ b/main.cpp
@@ -15,16 +15,24 @@ po::variables_map vm;
int train() {
RSVM *rsvm;
- rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str());
- FileDP dp(vm["feature"].as<std::string>().c_str());
+ rsvm = RSVM::loadModel(vm["model"].as<std::string>());
+ FileDP dp(vm["feature"].as<std::string>());
+
+ // Generic training operations
+ dp.open();
DataSet D;
Labels L;
+ LOG(INFO)<<"Training started";
while (!dp.EOFile())
{
dp.getDataSet(D);
dp.getLabel(L);
rsvm->train(D,L);
}
+
+ LOG(INFO)<<"Training finished,saving model";
+
+
rsvm->saveModel(vm["output"].as<std::string>().c_str());
delete rsvm;
return 0;
@@ -45,6 +53,12 @@ int predict() {
return 0;
}
+int validate()
+{
+ LOG(FATAL)<<"Not Implemented";
+ return 0;
+}
+
int main(int argc, char **argv) {
// Defining program options
po::options_description desc("Allowed options");
@@ -73,7 +87,7 @@ int main(int argc, char **argv) {
}
else if (vm.count("validate")) {
LOG(INFO) << "Program option: validate";
- predict();
+ validate();
}
else if (vm.count("predict")) {
LOG(INFO) << "Program option: predict";
diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp
index 6294245..060001b 100644
--- a/model/ranksvm.cpp
+++ b/model/ranksvm.cpp
@@ -12,7 +12,8 @@ int RSVM::saveModel(const string fname){
std::ofstream fout(fname.c_str());
fout<<this->getName()<<endl;
- fout<<this->model;
+ fout<<this->fsize<<endl;
+ Eigen::write_stream(fout, this->model);
return 0;
}
@@ -37,8 +38,8 @@ RSVM* RSVM::loadModel(const string fname){
}
int RSVM::setModel(const Eigen::VectorXd &model) {
- if (model.cols()!=fsize)
- LOG(FATAL) << "Feature size mismatch";
+ if (model.rows()!=fsize)
+ LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.cols();
this->model=model;
return 0;
} \ No newline at end of file
diff --git a/tools/dataProvider.h b/tools/dataProvider.h
index 0e6ed9e..ce2bf12 100644
--- a/tools/dataProvider.h
+++ b/tools/dataProvider.h
@@ -24,7 +24,9 @@ class DataProvider //Virtual base class for data input
protected:
int size;
int attrSize;
+ bool eof;
public:
+ DataProvider():eof(false){};
int getSize(){
return size;
}
@@ -32,10 +34,11 @@ public:
return attrSize;
}
+ bool EOFile(){return eof;};
+
virtual int getDataSet(DataSet &out) = 0;
virtual int getLabel(Labels &out) = 0;
virtual int open()=0;
- virtual bool EOFile()=0;
};
#endif \ No newline at end of file
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index fd8f00d..8a499ca 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -8,9 +8,8 @@ class FileDP:public DataProvider
{
private:
std::string fname;
- bool eof;
public:
- FileDP(std::string fn=""):fname(fn),eof(false){};
+ FileDP(std::string fn=""):fname(fn){};
void setFname(std::string fn){fname=fn;};
virtual int getDataSet(DataSet &out){
return 0;
@@ -18,8 +17,7 @@ public:
virtual int getLabel(Labels &out){
return 0;
};
- virtual int open(){return 0;};
- virtual bool EOFile(){return eof;};
+ virtual int open(){eof=true;return 0;};
};
#endif \ No newline at end of file