diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-04-12 10:59:08 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-04-12 10:59:08 +0800 | 
| commit | 4662779251de3b692c20d4e10980a795f04e7520 (patch) | |
| tree | 9c73cb40236f3c8134f465a5eccbab0837d199df /main.cpp | |
| parent | 6c77acb550288883c25b3c2a769313d5466dda70 (diff) | |
| download | ranksvm-4662779251de3b692c20d4e10980a795f04e7520.tar.gz ranksvm-4662779251de3b692c20d4e10980a795f04e7520.tar.bz2 ranksvm-4662779251de3b692c20d4e10980a795f04e7520.zip | |
validate, nDCG
Diffstat (limited to 'main.cpp')
| -rw-r--r-- | main.cpp | 44 | 
1 files changed, 24 insertions, 20 deletions
| @@ -5,8 +5,7 @@  #include "tools/easylogging++.h"  #include "model/ranksvmtn.h"  #include "tools/fileDataProvider.h" -#include "tools/matrixIO.h" -#include <fstream> +#include "model/rankaccu.h"  INITIALIZE_EASYLOGGINGPP @@ -28,6 +27,10 @@ int train() {      dp.getDataSet(D);      LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";      rsvm->train(D); +    std::vector<double> L; +    rsvm->predict(D,L); + +    rank_accu(D,L);      LOG(INFO)<<"Training finished,saving model"; @@ -54,24 +57,31 @@ int predict() {          rsvm->predict(D,L);      } -    LOG(INFO)<<"Finished,saving prediction"; -    std::ofstream fout(vm["output"].as<std::string>().c_str()); +    if (vm.count("validate")) +    { +        rank_accu(D,L); +    } -    for (int i=0; i<L.size();++i) -        fout<<L[i]<<std::endl; -    fout.close(); +    if (vm.count("output")) +    { +        LOG(INFO)<<"Finished,saving prediction"; +        std::ofstream fout(vm["output"].as<std::string>().c_str()); +        for (int i=0; i<L.size();++i) +            fout<<L[i]<<std::endl; +        fout.close(); +    } +    else if (!vm.count("validate")) +    { +        LOG(INFO)<<"Finished"; +        for (int i=0; i<L.size();++i) +            std::cout<<L[i]<<std::endl; +    }      dp.close();      delete rsvm;      return 0;  } -int validate() -{ -    LOG(FATAL)<<"Not Implemented"; -    return 0; -} -  int main(int argc, char **argv) {      // Defining program options      po::options_description desc("Allowed options"); @@ -95,15 +105,9 @@ int main(int argc, char **argv) {      }      if (vm.count("train")) { -        LOG(INFO) << "Program option: training";          train();      } -    else if (vm.count("validate")) { -        LOG(INFO) << "Program option: validate"; -        validate(); -    } -    else if (vm.count("predict")) { -        LOG(INFO) << "Program option: predict"; +    else if (vm.count("validate")||vm.count("predict")) {          predict();      }      return 0; | 
