diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-06-01 20:53:56 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-06-01 20:53:56 +0800 |
commit | 8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df (patch) | |
tree | b441797b9c494654d7f8607b4560a1a27e67dd26 | |
download | cross-8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df.tar.gz cross-8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df.tar.bz2 cross-8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df.zip |
init commit
-rw-r--r-- | main.py | 76 | ||||
-rw-r--r-- | misc.py | 30 |
2 files changed, 106 insertions, 0 deletions
@@ -0,0 +1,76 @@ +from misc import * + +clist=[0.001] + +inr="a.rid" + +folds=4 + +tot=316 + +for ite in range(10): + print("iter %d" % ite) + inm="%d.m" % ite + + o=ite+1 + + step = tot // folds + + ourb=inr + + for i in range(folds-1): + inra=ourb + oura="a%d.rid"%i + ourb="b%d.rid"%i + params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb) + split(params) + + inra=ourb + oura="a%d.rid"% (folds-1) + params = "-s -i %s -a %s"%(inra,oura) + split(params) + + for a in range(folds): + entries=['0','0'] + for b in range(folds): + if b!=a: + tmp=take("a%d.rid"%b) + entries = merge(tmp,entries) + rid="b%d.rid" %a + put(rid,entries) + + optc=1 + bcmc=-1 + + for c in clist: + print(c) + oum="%d-%g.m" % (o,c) + + acmc = -1; + for a in range(folds): + rid = "b%d.rid"%a + params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,rid,oum,c) + train(params) + params = "-V -C -m %s -i %s -s" %(oum,rid) + ncmc=cmc(params) + if acmc==-1: + acmc=ncmc + else: + for b in range(100): + acmc[b]+=ncmc[b] + + if bcmc==-1: + optc=c + bcmc=ncmc + else: + for b in range(100): + if bcmc[b]<ncmc[b]: + optc=c + bcmc=ncmc + break + + + oum="%d.m"% o + params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,inr,oum,optc) + train(params) + @@ -0,0 +1,30 @@ +from subprocess import (check_output,call) +import os + +def split(params): + devnull = open(os.devnull, 'w') + call(["./split"]+params.split(" "),stdout=devnull) + +def train(params): + devnull = open(os.devnull, 'w') + call(["./ranksvm"]+params.split(" "),stdout=devnull) + +def cmc(params): + devnull = open(os.devnull, 'w') + retcode = check_output(["./ranksvm"]+params.split(" ")); + return [float(i) for i in retcode.split('\n')[1:]] + +def take(fname): + f=open(fname,'r') + res=f.read().split('\n') + f.close() + return res + +def merge(a,b): + return a[:-1]+b[1:] + +def put(fname,a): + f=open(fname,'w') + for item in a: + f.write("%s\n" % item) + f.close() |