diff options
| -rw-r--r-- | main.cpp | 59 | ||||
| -rw-r--r-- | model/rankaccu.cpp | 31 | ||||
| -rw-r--r-- | model/rankaccu.h | 30 | 
3 files changed, 99 insertions, 21 deletions
@@ -10,6 +10,7 @@  INITIALIZE_EASYLOGGINGPP  using namespace Eigen; +using namespace std;  namespace po = boost::program_options;  po::variables_map vm; @@ -18,7 +19,7 @@ typedef int (*mainFunc)(DataProvider &dp);  int train(DataProvider &dp) {      RSVM *rsvm; -    rsvm = RSVM::loadModel(vm["model"].as<std::string>()); +    rsvm = RSVM::loadModel(vm["model"].as<string>());      dp.open();      DataList D; @@ -27,31 +28,39 @@ int train(DataProvider &dp) {      dp.getAllData(D);      LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";      rsvm->train(D); -    std::vector<double> L; +    vector<double> L;      rsvm->predict(D,L); -    rank_accu(D,L); -      LOG(INFO)<<"Training finished,saving model";      dp.close(); -    rsvm->saveModel(vm["output"].as<std::string>().c_str()); +    rsvm->saveModel(vm["output"].as<string>().c_str());      delete rsvm;      return 0;  }  int predict(DataProvider &dp) {      RSVM *rsvm; -    rsvm = RSVM::loadModel(vm["model"].as<std::string>().c_str()); +    rsvm = RSVM::loadModel(vm["model"].as<string>().c_str());      dp.open();      DataList D; -    std::vector<double> L; +    vector<double> L; +    CMC cmc;      LOG(INFO)<<"Prediction started"; -    std::ofstream fout; +    ofstream fout;      if (vm.count("output")) -        fout.open(vm["output"].as<std::string>().c_str()); +        fout.open(vm["output"].as<string>().c_str()); + +    ostream* ot; + +    if (vm.count("output")) { +        fout.open(vm["output"].as<string>().c_str()); +        ot=&fout; +    } +    else +        ot=&cout;      while (!dp.EOFile())      { @@ -62,17 +71,24 @@ int predict(DataProvider &dp) {          if (vm.count("validate"))          {              rank_accu(D,L); +            if (vm.count("cmc")) +                rank_CMC(D,L,cmc);          } -        if (vm.count("output")) -            for (int i=0; i<L.size();++i) -                fout<<L[i]<<std::endl; -        else if (!vm.count("validate")) +        if (vm.count("output") || !vm.count("validate"))              for (int i=0; i<L.size();++i) -                std::cout<<L[i]<<std::endl; +                *ot<<L[i]<<endl;      }      LOG(INFO)<<"Finished"; +    if (vm.count("cmc")) +    { +        LOG(INFO)<< "CMC accounted over " <<cmc.getCount() << " queries"; +        *ot << "CMC"<<endl; +        vector<double> cur = cmc.getAcc(); +        for (int i = 0;i<CMC_MAX;++i) +        *ot << cur[i]<<endl; +    }      if (vm.count("output"))          fout.close();      dp.close(); @@ -88,9 +104,10 @@ int main(int argc, char **argv) {              ("train,T", "training model")              ("validate,V", "validate model")              ("predict,P", "use model for prediction") -            ("model,m", po::value<std::string>(), "set input model file") -            ("output,o", po::value<std::string>(), "set output model/prediction file") -            ("feature,i", po::value<std::string>(), "set input feature file"); +            ("cmc,C", "enable cmc auditing") +            ("model,m", po::value<string>(), "set input model file") +            ("output,o", po::value<string>(), "set output model/prediction file") +            ("feature,i", po::value<string>(), "set input feature file");      // Parsing program options      po::store(po::parse_command_line(argc, argv, desc), vm); @@ -98,7 +115,7 @@ int main(int argc, char **argv) {      // Print help if necessary      if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) { -        std::cout << desc; +        cout << desc;          return 0;      }      mainFunc mainf; @@ -110,10 +127,10 @@ int main(int argc, char **argv) {      }      else return 0;      DataProvider* dp; -    if (vm["feature"].as<std::string>().find(".rid") == std::string::npos) -        dp = new FileDP(vm["feature"].as<std::string>()); +    if (vm["feature"].as<string>().find(".rid") == string::npos) +        dp = new FileDP(vm["feature"].as<string>());      else -        dp = new RidFileDP(vm["feature"].as<std::string>()); +        dp = new RidFileDP(vm["feature"].as<string>());      mainf(*dp);      delete dp;      return 0; diff --git a/model/rankaccu.cpp b/model/rankaccu.cpp index 069e245..1763e0f 100644 --- a/model/rankaccu.cpp +++ b/model/rankaccu.cpp @@ -112,4 +112,35 @@ int rank_accu(DataList &D,const vector<double> pred)          ++i;      }      LOG(INFO)<<"over "<< cnt<< " queries. "<<"Average nDGC: "<< accu_nDCG/cnt<< " Average AP: "<<accu_AP/cnt; +} + +int rank_CMC(DataList &D,const std::vector<double> pred,CMC & cmc) { +    unsigned long n = D.getSize(); +    vector<int> orig_rank(n),pred_rank(n),C(n); +    vector<double> orig(n); +    int i,j; +    for (i=0;i<D.getSize();++i) +    { +        orig_rank[i]=i; +        pred_rank[i]=i; +        orig[i]=D.getData()[i]->rank; +    } +    int cnt=0; +    i=j=0; +    while (i<D.getSize()) +    { +        if ((i+1 == D.getSize())|| D.getData()[i]->qid!=D.getData()[i+1]->qid) +        { +            ranksort(j,i,pred_rank,pred,orig); +            for (int k=j;k<=i;++k) +                if (orig[pred_rank[k]]>0) +                { +                    cmc.addEntry(k-j); +                    break; // account only for the first match; +                } +            j = i+1; +            ++cnt; +        } +        ++i; +    }  }
\ No newline at end of file diff --git a/model/rankaccu.h b/model/rankaccu.h index 3fe5379..adf1a1f 100644 --- a/model/rankaccu.h +++ b/model/rankaccu.h @@ -8,6 +8,36 @@  #include<vector>  #include"../tools/dataProvider.h" +#define CMC_MAX 100 + +class CMC +{ +private: +    std::vector<double> acc; +    int cnt; +public: +    void clear(){for (int i=0;i<CMC_MAX;++i) acc[i]=0; cnt=0;}; +    CMC(){acc.reserve(CMC_MAX); clear();}; +    void addEntry(int idx) { ++cnt; if (idx <CMC_MAX) acc[idx]+=1;} +    std::vector<double> getAcc() { +        std::vector<double> res; +        res.reserve(CMC_MAX); +        double cumu = 0; +        for (int i=0;i<CMC_MAX;++i) +        { +            cumu += acc[i]; +            res[i] = cumu / cnt; +        } +        return res; +    } +    int getCount() +    { +        return cnt; +    } +}; + +int rank_CMC(DataList &D,const std::vector<double> pred,CMC & cmc); +  int rank_accu(DataList &D,const std::vector<double> pred);  #endif //RANKSVM_RANKACCU_H  | 
