summaryrefslogtreecommitdiff
path: root/ensemble-train.py
blob: 1ceb380446e16132ea4d874813936b8d2f078e69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from misc import *
import math
# Input data file: it seeds the first split() call and is the -i argument of
# the (currently disabled) prediction command.
inr="a.rid"
# Path handed to putmodel() for the final combined model.
resm="res.m"
# Values substituted for -c in the training command; each value also becomes
# part of the corresponding weak model / prediction file name.
clist=[0.0001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 10, 100, 1000]

# Number of folds; drives every fold loop below.
folds=4
# NOTE(review): presumably the total number of entries in `inr`; it is only
# used to derive `step` -- confirm against the data file.
tot=316
# Value passed as -c to every split() call.
step = tot // folds
# Model file passed as -m to the training command.
inm="0.m"

# Seed the split chain: the first split() call reads the full input file.
ourb=inr
# Fold preparation.
# Chain folds-1 split() calls: each call reads the -b output of the previous
# call (the original input on the first pass) and names a<i>.rid / b<i>.rid
# as its -a / -b outputs.
# NOTE(review): presumably -c is the number of entries routed to the -a file
# and -b receives the remainder -- confirm against misc.split.
print("splitting")
for i in range(folds-1):
    inra=ourb
    oura="a%d.rid"%i
    ourb="b%d.rid"%i
    params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
    split(params)

# Initial accumulator handed to merge().
# NOTE(review): looks like the representation of an empty .rid file -- confirm.
entries=['0','0']

# c0.rid is a copy of b0.rid.
tmp=take("b0.rid")
put("c0.rid",tmp)

# The -b output of the last split is stored as the final a<folds-1>.rid.
tmp=take("b%d.rid"%(folds-2))
put("a%d.rid"%(folds-1),tmp)
print("merging")
# Build c1.rid .. c<folds-1>.rid: `entries` is re-merged with a<a>.rid on every
# pass; c<a+1>.rid is that accumulation merged with b<a+1>.rid, except on the
# last pass where the accumulation is written as-is.
# NOTE(review): presumably c<k>.rid ends up as the data with fold k left out --
# confirm against misc.merge.
for a in range(folds-1):
    tmp=take("a%d.rid"%a)
    entries = merge(tmp,entries)
	
    if a<folds-2:
        tmp=take("b%d.rid"%(a+1))
        tmp = merge(tmp,entries)
    else:
        # Last pass: there is no b file left to add.
        tmp = entries;
	
    rid="c%d.rid" %(a+1)
    put(rid,tmp)

# Run split() on every c<i>.rid, reusing the a<i>.rid / b<i>.rid names from the
# first phase as outputs.
# NOTE(review): the two `entries` lines index with `a`, which still holds the
# final value (folds-2) left over from the merge loop above, not `i`; the
# resulting `entries` is not read again in this file. Looks like leftover code
# -- confirm the intent before changing or removing it.
for i in range(folds):
    entries = take("a%d.rid"%a)
    inra="c%d.rid"%i
    oura="a%d.rid"%i
    ourb="b%d.rid"%i
    params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
    split(params)
    entries = merge(take("a%d.rid"%a),entries)
print("completed")
# Weak-ranker bookkeeping: one "<fold>-<c>" tag per (fold, c) pair, plus the
# largest absolute prediction value found in any of their prediction files.
wlist = []
mai = 0
for fold in range(folds):
    train_rid = "a%d.rid" % fold
    for cost in clist:
        print("folds: %d ,c: %g" % (fold, cost))
        tag = "%d-%g" % (fold, cost)
        wlist.append(tag)
        oum = tag + ".m"
        oup = tag + ".p"

        # The external train / predict invocations are disabled; their
        # argument strings are still assembled so the calls can be re-enabled.
        # NOTE(review): with both calls disabled, getpred() presumably relies
        # on the .p files already being present -- confirm.
        train_args = "-T -d -m %s -i %s -o %s -c %g" % (inm, train_rid, oum, cost)
        #train(train_args)
        predict_args = "-P -p -m %s -i %s -o %s" % (oum, inr, oup)
        #bare(predict_args)

        # Track the running maximum of |prediction| over all weak rankers.
        for pred in getpred(oup):
            mai = max(abs(pred), mai)

# The normaliser used later is twice the largest absolute prediction.
mai *= 2

def _weighted_error(weights, preds):
    """Return the summed weight of the samples whose prediction is <= 0."""
    total = 0
    for weight, pred in zip(weights, preds):
        if pred <= 0:
            total += weight
    return total


# Initial state: a uniform sample distribution with the same length as the
# last prediction file, and an all-zero model with the same length as the
# last weak model.
D = getpred(oup)
n_samples = len(D)
for idx in range(n_samples):
    D[idx] = 1 / n_samples
mod = getmodel(oum)
for idx in range(len(mod)):
    mod[idx] = 0

# Boosting-style rounds: every round consumes one weak ranker until none remain.
while wlist:
    # Select the remaining weak ranker with the smallest weighted error.
    best_err = 1e20
    best_tag = 0
    best_preds = []
    for tag in wlist:
        preds = getpred(tag + ".p")
        err = _weighted_error(D, preds)
        if err < best_err:
            best_err, best_tag, best_preds = err, tag, preds

    print(best_tag)
    wlist.remove(best_tag)

    # Weight of the chosen ranker. Dividing by mai (twice the largest
    # |prediction|) keeps the ratio inside the logarithm positive.
    corr = 0
    for weight, pred in zip(D, best_preds):
        corr += weight * pred
    corr /= mai
    alpha = 0.5 * math.log((1 + corr) / (1 - corr))

    # Add the chosen ranker's parameters, scaled by alpha, into the ensemble.
    weak_model = getmodel(best_tag + ".m")
    for idx in range(len(mod)):
        mod[idx] += alpha * weak_model[idx]

    # Re-weight the samples by exp(-alpha * prediction) ...
    for idx in range(len(D)):
        D[idx] *= math.exp(-alpha * best_preds[idx])

    # ... and rescale so the weights sum to one again.
    norm = 0
    for weight in D:
        norm += weight
    for idx in range(len(D)):
        D[idx] /= norm

# Emit the combined model.
print(mod)
putmodel(resm, mod)