diff options
| -rw-r--r-- | main.py | 76 | ||||
| -rw-r--r-- | misc.py | 30 | 
2 files changed, 106 insertions, 0 deletions
| @@ -0,0 +1,76 @@ +from misc import * + +clist=[0.001] + +inr="a.rid" + +folds=4 + +tot=316 + +for ite in range(10): +    print("iter %d" % ite) +    inm="%d.m" % ite +     +    o=ite+1 + +    step = tot // folds + +    ourb=inr + +    for i in range(folds-1): +        inra=ourb +        oura="a%d.rid"%i +        ourb="b%d.rid"%i +        params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb) +        split(params) +     +    inra=ourb +    oura="a%d.rid"% (folds-1) +    params = "-s -i %s -a %s"%(inra,oura) +    split(params) +     +    for a in range(folds): +        entries=['0','0'] +        for b in range(folds): +            if b!=a: +                tmp=take("a%d.rid"%b) +                entries = merge(tmp,entries) +        rid="b%d.rid" %a +        put(rid,entries) + +    optc=1 +    bcmc=-1 + +    for c in clist: +        print(c) +        oum="%d-%g.m" % (o,c) +         +        acmc = -1; +        for a in range(folds): +            rid = "b%d.rid"%a +            params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,rid,oum,c) +            train(params) +            params = "-V -C -m %s -i %s -s" %(oum,rid) +            ncmc=cmc(params) +            if acmc==-1: +                acmc=ncmc +            else: +                for b in range(100): +                    acmc[b]+=ncmc[b] +         +        if bcmc==-1: +            optc=c +            bcmc=ncmc +        else: +            for b in range(100): +                if bcmc[b]<ncmc[b]: +                    optc=c +                    bcmc=ncmc +                    break + +     +    oum="%d.m"% o +    params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,inr,oum,optc) +    train(params) + @@ -0,0 +1,30 @@ +from subprocess import (check_output,call) +import os + +def split(params): +    devnull = open(os.devnull, 'w') +    call(["./split"]+params.split(" "),stdout=devnull) + +def train(params): +    devnull = open(os.devnull, 'w') +    call(["./ranksvm"]+params.split(" "),stdout=devnull) + +def cmc(params): +    devnull = open(os.devnull, 'w') +    retcode = check_output(["./ranksvm"]+params.split(" ")); +    return [float(i) for i in retcode.split('\n')[1:]] + +def take(fname): +    f=open(fname,'r') +    res=f.read().split('\n') +    f.close() +    return res + +def merge(a,b): +    return a[:-1]+b[1:] + +def put(fname,a): +    f=open(fname,'w') +    for item in a: +        f.write("%s\n" % item) +    f.close() | 
