summary | refs | log | tree | commit | diff
path: root/ensemble-train.py
diff options
context:
space:
mode:
Diffstat (limited to 'ensemble-train.py')
-rw-r--r-- ensemble-train.py 217
1 files changed, 110 insertions, 107 deletions
diff --git a/ensemble-train.py b/ensemble-train.py
index 1ceb380..34b4024 100644
--- a/ensemble-train.py
+++ b/ensemble-train.py
@@ -1,124 +1,127 @@
from misc import *
import math
-inr="a.rid"
-resm="res.m"
+
+def onefold(inr,inm,tot,clist,folds,resm)
+
+#inr="a.rid"
+#resm="res.m"
#consts used
-clist=[0.0001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 10, 100, 1000]
+#clist=[0.0001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 10, 100, 1000]
-folds=4
-tot=316
-step = tot // folds
-inm="0.m"
+#folds=4
+#tot=316
+ step = tot // folds
+#inm="0.m"
-ourb=inr
+ ourb=inr
#splits
-print("splitting")
-for i in range(folds-1):
- inra=ourb
- oura="a%d.rid"%i
- ourb="b%d.rid"%i
- params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
- split(params)
+ print("splitting")
+ for i in range(folds-1):
+ inra=ourb
+ oura="a%d.rid"%i
+ ourb="b%d.rid"%i
+ params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
+ split(params)
-entries=['0','0']
+ entries=['0','0']
-tmp=take("b0.rid")
-put("c0.rid",tmp)
+ tmp=take("b0.rid")
+ put("c0.rid",tmp)
-tmp=take("b%d.rid"%(folds-2))
-put("a%d.rid"%(folds-1),tmp)
-print("merging")
-for a in range(folds-1):
- tmp=take("a%d.rid"%a)
- entries = merge(tmp,entries)
-
- if a<folds-2:
- tmp=take("b%d.rid"%(a+1))
- tmp = merge(tmp,entries)
- else:
- tmp = entries;
-
- rid="c%d.rid" %(a+1)
- put(rid,tmp)
+ tmp=take("b%d.rid"%(folds-2))
+ put("a%d.rid"%(folds-1),tmp)
+ print("merging")
+ for a in range(folds-1):
+ tmp=take("a%d.rid"%a)
+ entries = merge(tmp,entries)
+
+ if a<folds-2:
+ tmp=take("b%d.rid"%(a+1))
+ tmp = merge(tmp,entries)
+ else:
+ tmp = entries;
+
+ rid="c%d.rid" %(a+1)
+ put(rid,tmp)
-for i in range(folds):
- entries = take("a%d.rid"%a)
- inra="c%d.rid"%i
- oura="a%d.rid"%i
- ourb="b%d.rid"%i
- params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
- split(params)
- entries = merge(take("a%d.rid"%a),entries)
-print("completed")
-wlist=[]
-mai=0
+ for i in range(folds):
+ entries = take("a%d.rid"%a)
+ inra="c%d.rid"%i
+ oura="a%d.rid"%i
+ ourb="b%d.rid"%i
+ params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
+ split(params)
+ entries = merge(take("a%d.rid"%a),entries)
+ print("completed")
+ wlist=[]
+ mai=0
#train
-for i in range(folds):
- for c in clist:
- print("folds: %d ,c: %g"%(i,c))
- wlist.append("%d-%g"%(i,c))
- oum="%d-%g.m" % (i,c)
- rid = "a%d.rid"%i
- params = "-T -d -m %s -i %s -o %s -c %g" % (inm,rid,oum,c)
- #train(params)
- oup="%d-%g.p"%(i,c)
- params = "-P -p -m %s -i %s -o %s" %(oum,inr,oup)
- #bare(params)
- for p in getpred(oup):
- mai=max(abs(p),mai);
+ for i in range(folds):
+ for c in clist:
+ print("folds: %d ,c: %g"%(i,c))
+ wlist.append("%d-%g"%(i,c))
+ oum="%d-%g.m" % (i,c)
+ rid = "a%d.rid"%i
+ params = "-T -d -m %s -i %s -o %s -c %g" % (inm,rid,oum,c)
+ #train(params)
+ oup="%d-%g.p"%(i,c)
+ params = "-P -p -m %s -i %s -o %s" %(oum,inr,oup)
+ #bare(params)
+ for p in getpred(oup):
+ mai=max(abs(p),mai);
-mai=mai*2;
+ mai=mai*2;
#inits
-D=getpred(oup)
-for i in range(len(D)):
- D[i]=1/len(D);
-mod=getmodel(oum);
-for i in range(len(mod)):
- mod[i]=0;
-
-while len(wlist)>0:
- low=1e20
- k=0
- P=[]
- #find best weak ranker
- for w in wlist:
- t=0
- pr=getpred(w+".p")
- for (d,p) in zip(D,pr):
- if p<=0:
- t+=d
- if t<low:
- low=t
- k=w
- P=pr
-
- print(k)
- wlist.remove(k)
- # cal alpha
- r=0;
- for (d,p) in zip(D,P):
- r+=d*p;
- r=r/mai;
- a=0.5*math.log((1+r)/(1-r))
-
- #update model
- tmod=getmodel(k+".m")
- for i in range(len(mod)):
- mod[i]+=a*tmod[i];
-
- #update D
+ D=getpred(oup)
for i in range(len(D)):
- D[i]=D[i]*math.exp(-a*P[i]);
-
- #normalize D
- acc=0;
- for d in D:
- acc+=d;
-
- for i in range(len(D)):
- D[i]/=acc;
+ D[i]=1/len(D);
+ mod=getmodel(oum);
+ for i in range(len(mod)):
+ mod[i]=0;
+
+ while len(wlist)>0:
+ low=1e20
+ k=0
+ P=[]
+ #find best weak ranker
+ for w in wlist:
+ t=0
+ pr=getpred(w+".p")
+ for (d,p) in zip(D,pr):
+ if p<=0:
+ t+=d
+ if t<low:
+ low=t
+ k=w
+ P=pr
+
+ print(k)
+ wlist.remove(k)
+ # cal alpha
+ r=0;
+ for (d,p) in zip(D,P):
+ r+=d*p;
+ r=r/mai;
+ a=0.5*math.log((1+r)/(1-r))
+
+ #update model
+ tmod=getmodel(k+".m")
+ for i in range(len(mod)):
+ mod[i]+=a*tmod[i];
+
+ #update D
+ for i in range(len(D)):
+ D[i]=D[i]*math.exp(-a*P[i]);
+
+ #normalize D
+ acc=0;
+ for d in D:
+ acc+=d;
+
+ for i in range(len(D)):
+ D[i]/=acc;
#output model
-print(mod)
-putmodel(resm,mod)
+ #print(mod)
+ putmodel(resm,mod)