diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-03-08 22:25:52 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-03-08 22:25:52 +0800 |
commit | 3d204f5fe4614624ca342090feecbfe4df188d9d (patch) | |
tree | d6f5e8871bb43dfc550de562d1d2811bd1023445 | |
parent | f2d01e30f459818f0589e06839d38999aecfdc06 (diff) | |
download | ranksvm-3d204f5fe4614624ca342090feecbfe4df188d9d.tar.gz ranksvm-3d204f5fe4614624ca342090feecbfe4df188d9d.tar.bz2 ranksvm-3d204f5fe4614624ca342090feecbfe4df188d9d.zip |
scaffolding, tested
-rw-r--r-- | main.cpp | 20 | ||||
-rw-r--r-- | model/ranksvm.cpp | 7 | ||||
-rw-r--r-- | tools/dataProvider.h | 5 | ||||
-rw-r--r-- | tools/fileDataProvider.h | 6 |
4 files changed, 27 insertions, 11 deletions
@@ -15,16 +15,24 @@ po::variables_map vm; int train() { RSVM *rsvm; - rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); - FileDP dp(vm["feature"].as<std::string>().c_str()); + rsvm = RSVM::loadModel(vm["model"].as<std::string>()); + FileDP dp(vm["feature"].as<std::string>()); + + // Generic training operations + dp.open(); DataSet D; Labels L; + LOG(INFO)<<"Training started"; while (!dp.EOFile()) { dp.getDataSet(D); dp.getLabel(L); rsvm->train(D,L); } + + LOG(INFO)<<"Training finished,saving model"; + + rsvm->saveModel(vm["output"].as<std::string>().c_str()); delete rsvm; return 0; @@ -45,6 +53,12 @@ int predict() { return 0; } +int validate() +{ + LOG(FATAL)<<"Not Implemented"; + return 0; +} + int main(int argc, char **argv) { // Defining program options po::options_description desc("Allowed options"); @@ -73,7 +87,7 @@ int main(int argc, char **argv) { } else if (vm.count("validate")) { LOG(INFO) << "Program option: validate"; - predict(); + validate(); } else if (vm.count("predict")) { LOG(INFO) << "Program option: predict"; diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 6294245..060001b 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -12,7 +12,8 @@ int RSVM::saveModel(const string fname){ std::ofstream fout(fname.c_str()); fout<<this->getName()<<endl; - fout<<this->model; + fout<<this->fsize<<endl; + Eigen::write_stream(fout, this->model); return 0; } @@ -37,8 +38,8 @@ RSVM* RSVM::loadModel(const string fname){ } int RSVM::setModel(const Eigen::VectorXd &model) { - if (model.cols()!=fsize) - LOG(FATAL) << "Feature size mismatch"; + if (model.rows()!=fsize) + LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.cols(); this->model=model; return 0; }
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 0e6ed9e..ce2bf12 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -24,7 +24,9 @@ class DataProvider //Virtual base class for data input protected: int size; int attrSize; + bool eof; public: + DataProvider():eof(false){}; int getSize(){ return size; } @@ -32,10 +34,11 @@ public: return attrSize; } + bool EOFile(){return eof;}; + virtual int getDataSet(DataSet &out) = 0; virtual int getLabel(Labels &out) = 0; virtual int open()=0; - virtual bool EOFile()=0; }; #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index fd8f00d..8a499ca 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -8,9 +8,8 @@ class FileDP:public DataProvider { private: std::string fname; - bool eof; public: - FileDP(std::string fn=""):fname(fn),eof(false){}; + FileDP(std::string fn=""):fname(fn){}; void setFname(std::string fn){fname=fn;}; virtual int getDataSet(DataSet &out){ return 0; @@ -18,8 +17,7 @@ public: virtual int getLabel(Labels &out){ return 0; }; - virtual int open(){return 0;}; - virtual bool EOFile(){return eof;}; + virtual int open(){eof=true;return 0;}; }; #endif
\ No newline at end of file |