diff options
-rw-r--r-- | model/rankaccu.h | 9 | ||||
-rw-r--r-- | train.cpp | 2 |
2 files changed, 9 insertions, 2 deletions
diff --git a/model/rankaccu.h b/model/rankaccu.h index 832a9f5..e8da882 100644 --- a/model/rankaccu.h +++ b/model/rankaccu.h @@ -15,10 +15,11 @@ class CMC private: std::vector<double> acc; int cnt; + double sum; public: - void clear(){for (int i=0;i<CMC_MAX;++i) acc[i]=0; cnt=0;}; + void clear(){for (int i=0;i<CMC_MAX;++i) acc[i]=0; cnt=0; sum=0;}; CMC(){acc.reserve(CMC_MAX); clear();}; - void addEntry(int idx) { ++cnt; if (idx <CMC_MAX) acc[idx]+=1;} + void addEntry(int idx) { ++cnt; if (idx <CMC_MAX) acc[idx]+=1; sum+=idx;} std::vector<double> getAcc() { std::vector<double> res; res.reserve(CMC_MAX); @@ -34,6 +35,10 @@ public: { return cnt; } + double getAvg() + { + return sum/cnt; + } }; void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc); @@ -84,6 +84,8 @@ int predict(DataProvider &dp) { vector<double> cur = cmc.getAcc(); for (int i = 0;i<CMC_MAX;++i) *ot << cur[i]<<endl; + *ot << "AVG"<<endl; + *ot << cmc.getAvg()/D.getqSize() <<endl; } if (vm.count("output")) fout.close(); |