summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-06-01 20:53:56 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-06-01 20:53:56 +0800
commit8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df (patch)
treeb441797b9c494654d7f8607b4560a1a27e67dd26
downloadcross-8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df.tar.gz
cross-8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df.tar.bz2
cross-8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df.zip
init commit
-rw-r--r--main.py76
-rw-r--r--misc.py30
2 files changed, 106 insertions, 0 deletions
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..6b04638
--- /dev/null
+++ b/main.py
@@ -0,0 +1,76 @@
+from misc import *
+
+clist=[0.001]
+
+inr="a.rid"
+
+folds=4
+
+tot=316
+
+for ite in range(10):
+ print("iter %d" % ite)
+ inm="%d.m" % ite
+
+ o=ite+1
+
+ step = tot // folds
+
+ ourb=inr
+
+ for i in range(folds-1):
+ inra=ourb
+ oura="a%d.rid"%i
+ ourb="b%d.rid"%i
+ params = "-c %d -i %s -a %s -b %s" %(step,inra,oura,ourb)
+ split(params)
+
+ inra=ourb
+ oura="a%d.rid"% (folds-1)
+ params = "-s -i %s -a %s"%(inra,oura)
+ split(params)
+
+ for a in range(folds):
+ entries=['0','0']
+ for b in range(folds):
+ if b!=a:
+ tmp=take("a%d.rid"%b)
+ entries = merge(tmp,entries)
+ rid="b%d.rid" %a
+ put(rid,entries)
+
+ optc=1
+ bcmc=-1
+
+ for c in clist:
+ print(c)
+ oum="%d-%g.m" % (o,c)
+
+ acmc = -1;
+ for a in range(folds):
+ rid = "b%d.rid"%a
+ params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,rid,oum,c)
+ train(params)
+ params = "-V -C -m %s -i %s -s" %(oum,rid)
+ ncmc=cmc(params)
+ if acmc==-1:
+ acmc=ncmc
+ else:
+ for b in range(100):
+ acmc[b]+=ncmc[b]
+
+ if bcmc==-1:
+ optc=c
+ bcmc=ncmc
+ else:
+ for b in range(100):
+ if bcmc[b]<ncmc[b]:
+ optc=c
+ bcmc=ncmc
+ break
+
+
+ oum="%d.m"% o
+ params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,inr,oum,optc)
+ train(params)
+
diff --git a/misc.py b/misc.py
new file mode 100644
index 0000000..7bf1370
--- /dev/null
+++ b/misc.py
@@ -0,0 +1,30 @@
+from subprocess import (check_output,call)
+import os
+
+def split(params):
+ devnull = open(os.devnull, 'w')
+ call(["./split"]+params.split(" "),stdout=devnull)
+
+def train(params):
+ devnull = open(os.devnull, 'w')
+ call(["./ranksvm"]+params.split(" "),stdout=devnull)
+
+def cmc(params):
+ devnull = open(os.devnull, 'w')
+ retcode = check_output(["./ranksvm"]+params.split(" "));
+ return [float(i) for i in retcode.split('\n')[1:]]
+
+def take(fname):
+ f=open(fname,'r')
+ res=f.read().split('\n')
+ f.close()
+ return res
+
+def merge(a,b):
+ return a[:-1]+b[1:]
+
+def put(fname,a):
+ f=open(fname,'w')
+ for item in a:
+ f.write("%s\n" % item)
+ f.close()