diff options
| -rw-r--r-- | main.cpp | 20 | ||||
| -rw-r--r-- | model/ranksvm.cpp | 7 | ||||
| -rw-r--r-- | tools/dataProvider.h | 5 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 6 | 
4 files changed, 27 insertions, 11 deletions
@@ -15,16 +15,24 @@ po::variables_map vm;  int train() {      RSVM *rsvm; -    rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); -    FileDP dp(vm["feature"].as<std::string>().c_str()); +    rsvm = RSVM::loadModel(vm["model"].as<std::string>()); +    FileDP dp(vm["feature"].as<std::string>()); + +    // Generic training operations +    dp.open();      DataSet D;      Labels L; +    LOG(INFO)<<"Training started";      while (!dp.EOFile())      {          dp.getDataSet(D);          dp.getLabel(L);          rsvm->train(D,L);      } + +    LOG(INFO)<<"Training finished,saving model"; + +      rsvm->saveModel(vm["output"].as<std::string>().c_str());      delete rsvm;      return 0; @@ -45,6 +53,12 @@ int predict() {      return 0;  } +int validate() +{ +    LOG(FATAL)<<"Not Implemented"; +    return 0; +} +  int main(int argc, char **argv) {      // Defining program options      po::options_description desc("Allowed options"); @@ -73,7 +87,7 @@ int main(int argc, char **argv) {      }      else if (vm.count("validate")) {          LOG(INFO) << "Program option: validate"; -        predict(); +        validate();      }      else if (vm.count("predict")) {          LOG(INFO) << "Program option: predict"; diff --git a/model/ranksvm.cpp b/model/ranksvm.cpp index 6294245..060001b 100644 --- a/model/ranksvm.cpp +++ b/model/ranksvm.cpp @@ -12,7 +12,8 @@ int RSVM::saveModel(const string fname){      std::ofstream fout(fname.c_str());      fout<<this->getName()<<endl; -    fout<<this->model; +    fout<<this->fsize<<endl; +    Eigen::write_stream(fout, this->model);      return 0;  } @@ -37,8 +38,8 @@ RSVM* RSVM::loadModel(const string fname){  }  int RSVM::setModel(const Eigen::VectorXd &model) { -    if (model.cols()!=fsize) -        LOG(FATAL) << "Feature size mismatch"; +    if (model.rows()!=fsize) +        LOG(FATAL) << "Feature size mismatch: "<<fsize<<" "<<model.cols();      this->model=model;      return 0;  }
\ No newline at end of file diff --git a/tools/dataProvider.h b/tools/dataProvider.h index 0e6ed9e..ce2bf12 100644 --- a/tools/dataProvider.h +++ b/tools/dataProvider.h @@ -24,7 +24,9 @@ class DataProvider  //Virtual base class for data input  protected:      int size;      int attrSize; +    bool eof;  public: +    DataProvider():eof(false){};      int getSize(){          return size;      } @@ -32,10 +34,11 @@ public:          return attrSize;      } +    bool EOFile(){return eof;}; +      virtual int getDataSet(DataSet &out) = 0;      virtual int getLabel(Labels &out) = 0;      virtual int open()=0; -    virtual bool EOFile()=0;  };  #endif
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index fd8f00d..8a499ca 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -8,9 +8,8 @@ class FileDP:public DataProvider  {  private:      std::string fname; -    bool eof;  public: -    FileDP(std::string fn=""):fname(fn),eof(false){}; +    FileDP(std::string fn=""):fname(fn){};      void setFname(std::string fn){fname=fn;};      virtual int getDataSet(DataSet &out){          return 0; @@ -18,8 +17,7 @@ public:      virtual int getLabel(Labels &out){          return 0;      }; -    virtual int open(){return 0;}; -    virtual bool EOFile(){return eof;}; +    virtual int open(){eof=true;return 0;};  };  #endif
\ No newline at end of file  | 
