diff options
-rw-r--r-- | model/rankaccu.cpp | 19 | ||||
-rw-r--r-- | model/rankaccu.h | 2 | ||||
-rw-r--r-- | train.cpp | 13 |
3 files changed, 33 insertions, 1 deletions
diff --git a/model/rankaccu.cpp b/model/rankaccu.cpp index 8404abf..0f55e26 100644 --- a/model/rankaccu.cpp +++ b/model/rankaccu.cpp @@ -128,4 +128,23 @@ void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc) { break; // account only for the first match; } } +} + +void rank_pair(RidList &D,const vector<double> pred,vector<double> &pair) +{ + int n =D.getSize(),q=D.getqSize(); + pair.clear(); + for (int i=0;i<n;i+=q) + { + int corr=0; + for (int j=0;j<q;++j) + if (D.getL(i+j)>0) { + corr = j; + break; + } + + for (int j=0;j<q;++j) + if (j!=corr) + pair.push_back(pred[i+corr]-pred[i+j]); + } }
\ No newline at end of file diff --git a/model/rankaccu.h b/model/rankaccu.h index e8da882..8cac56c 100644 --- a/model/rankaccu.h +++ b/model/rankaccu.h @@ -45,4 +45,6 @@ void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc); void rank_accu(RidList &D,const std::vector<double> pred); +void rank_pair(RidList &D,const std::vector<double> pred,std::vector<double> &pair); + #endif //RANKSVM_RANKACCU_H @@ -72,9 +72,19 @@ int predict(DataProvider &dp) { rank_CMC(D,L,cmc); } - if (vm.count("output") && vm.count("predict")) + if (vm.count("predict")) + { + if (vm.count("pair")) + { + vector<double> pair; + rank_pair(D,L,pair); + for (int i=0;i<pair.size();++i) + *ot<<pair[i]<<endl; + } + else for (int i=0; i<L.size();++i) *ot<<L[i]<<endl; + } LOG(INFO)<<"Finished"; if (vm.count("cmc")) @@ -110,6 +120,7 @@ int main(int argc, char **argv) { ("cmc,C", "enable cmc auditing") ("debug,d", "show debug messages") ("single,s", "one from a pair") + ("pair,p","get pair result") ("model,m", po::value<string>(), "set input model file") ("output,o", po::value<string>(), "set output model/prediction file") ("feature,i", po::value<string>(), "set input feature file") |