summaryrefslogtreecommitdiff
path: root/model/rankaccu.cpp
blob: 2e77eb6fa211ccdd33103446c931d2c5d94006f0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
//
// Created by joe on 4/12/15.
//

#include "rankaccu.h"
#include "../tools/easylogging++.h"

using namespace std;

// Added to every relevance label before exponentiation: gain = 2^(offset + rank) - 1.
const double offset = 1.0;

/*
 * Reorders rank[l..r] (both ends inclusive) in place so that the items it
 * indexes appear in descending order of ref1, with ties on ref1 broken by
 * descending ref2. rank holds indices into ref1/ref2.
 * Recursive quicksort using the middle slot of the range as pivot; items that
 * compare equal on both keys end up in an unspecified (non-stable) order.
 */
void ranksort(int l,int r,vector<int> &rank,const vector<double> &ref1,const vector<double> &ref2)
{
    const int pivot = rank[(l + r) >> 1];
    const double key1 = ref1[pivot];
    const double key2 = ref2[pivot];

    // true when item idx belongs strictly before the pivot
    auto precedesPivot = [&](int idx) {
        return ref1[idx] > key1 || (ref1[idx] == key1 && ref2[idx] > key2);
    };
    // true when item idx belongs strictly after the pivot
    auto followsPivot = [&](int idx) {
        return ref1[idx] < key1 || (ref1[idx] == key1 && ref2[idx] < key2);
    };

    int lo = l;
    int hi = r;
    while (lo <= hi)
    {
        while (precedesPivot(rank[lo]))
            ++lo;
        while (followsPivot(rank[hi]))
            --hi;
        if (lo > hi)
            break;
        // exchange the two misplaced indices and step past them
        const int tmp = rank[lo];
        rank[lo] = rank[hi];
        rank[hi] = tmp;
        ++lo;
        --hi;
    }

    if (l < hi)
        ranksort(l, hi, rank, ref1, ref2);
    if (lo < r)
        ranksort(lo, r, rank, ref1, ref2);
}

/*
 * Computes the average nDCG of the predictions `pred` over all queries in D
 * and logs it at INFO level.
 *
 * A query group is a maximal run of consecutive items sharing the same qid.
 * NOTE(review): this assumes D is grouped by qid. Non-adjacent items with an
 * equal qid are evaluated as separate queries.
 *
 * For each group:
 *   - Z (ideal DCG) uses the items ordered by true rank, with ties broken by pred.
 *   - Y (achieved DCG) uses the items ordered by pred, with ties broken by true rank.
 *   - The gain of an item with true rank r is 2^(offset + r) - 1.
 *   - The discount at 0-based position p within the group is log2(2 + p).
 * The group's nDCG is Y / Z, and the logged value is the mean over groups.
 *
 * @param D    data set; D.getData()[i]->rank is used as the true relevance
 *             label and D.getData()[i]->qid as the query id of item i.
 * @param pred predicted scores, indexed like D.getData(); must have at least
 *             D.getSize() entries.
 * @return     number of query groups evaluated. Returns 0 when D is empty or
 *             pred is too short; nothing is averaged in those cases.
 */
int rank_accu(DataList &D,const vector<double> pred)
{
    // Fetch the item list once instead of on every access.
    const auto &data = D.getData();
    const int n = static_cast<int>(D.getSize());
    if (n <= 0)
    {
        LOG(WARNING) << "rank_accu: empty data set, nothing to evaluate";
        return 0;
    }

    vector<int> orig_rank(n), pred_rank(n);
    vector<double> orig(n);
    if (pred.size() < orig.size())
    {
        // Reading pred[i] for i >= pred.size() would be out of bounds.
        LOG(ERROR) << "rank_accu: got " << pred.size() << " predictions for " << n << " items";
        return 0;
    }

    for (int i = 0; i < n; ++i)
    {
        orig_rank[i] = i;
        pred_rank[i] = i;
        orig[i] = data[i]->rank;
    }

    int cnt = 0;
    double accu_nDCG = 0;
    int begin = 0;  // first index of the current query group
    for (int i = 0; i < n; ++i)
    {
        // i closes a group when it is the last item or the next qid differs.
        const bool groupEnds = (i + 1 == n) || data[i]->qid != data[i + 1]->qid;
        if (!groupEnds)
            continue;

        ranksort(begin, i, orig_rank, orig, pred);
        ranksort(begin, i, pred_rank, pred, orig);
        double Y = 0, Z = 0;
        for (int k = begin; k <= i; ++k)
        {
            const double discount = log2(2 + k - begin);
            Z += (pow(2, offset + orig[orig_rank[k]]) - 1) / discount;
            Y += (pow(2, offset + orig[pred_rank[k]]) - 1) / discount;
        }
        // NOTE(review): Z can only be 0 if every gain in the group is 0
        // (e.g. negative labels). Such a group would make the average NaN.
        accu_nDCG += Y / Z;
        begin = i + 1;
        ++cnt;
    }

    LOG(INFO) << "Average nDCG over " << cnt << " queries: " << accu_nDCG / cnt;
    return cnt;
}