summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cross.py5
-rw-r--r--genCMC.py23
-rw-r--r--misc.py8
-rw-r--r--train.py8
4 files changed, 39 insertions, 5 deletions
diff --git a/cross.py b/cross.py
index 7b997b2..09674f2 100644
--- a/cross.py
+++ b/cross.py
@@ -49,7 +49,7 @@ def onefold(inr,inm,tot,clist,folds,resm):
acmc = 0;
for a in range(folds):
rid = "b%d.rid"%a
- params = "-T -d -m %s -i %s -o %s -c %g --cg_iter 100" % (inm,rid,oum,c)
+ params = "-T -d -m %s -i %s -o %s -c %g" % (inm,rid,oum,c)
train(params)
rid="a%d.rid"%a
params = "-V -C -m %s -i %s -s" %(oum,rid)
@@ -60,6 +60,9 @@ def onefold(inr,inm,tot,clist,folds,resm):
if bcmc>acmc:
optc=c
bcmc=acmc
+ else:
+ if bcmc< acmc:
+ break
print("train with: %g" % optc)
params = "-T -d -m %s -i %s -o %s -c %g" % (inm,inr,resm,optc)
diff --git a/genCMC.py b/genCMC.py
new file mode 100644
index 0000000..f8e3872
--- /dev/null
+++ b/genCMC.py
@@ -0,0 +1,23 @@
+from misc import *
+
+oura="a.rid"
+ourb="b.rid"
+inra="cam.rid"
+b=[]
+for i in range(100):
+ b.append(0);
+for s in range(6):
+ inm="res%d.m"%s
+ for i in range(100):
+ b[i]=0;
+ for i in range(5):
+ print("%d-%d"%(s,i))
+ params = "-c 316 -i %s -a %s -b %s" %(inra,oura,ourb)
+ split(params)
+ params = "-V -C -m %s -i %s -s" %(inm,ourb)
+ tmp = cmcc(params)
+ for j in range(100):
+ b[j]=b[j]+tmp[j]
+ for j in range(100):
+ b[j]=b[j]/5
+ print(b)
diff --git a/misc.py b/misc.py
index 84ce869..b546ad3 100644
--- a/misc.py
+++ b/misc.py
@@ -6,7 +6,7 @@ def split(params):
call(["./split"]+params.split(" "),stdout=devnull)
def train(params):
- params+=" --iter 1 --cg_prec 1e-4 --ls_prec 1e-10 --prec 1e-4"
+ params+=" --iter 1 --cg_prec 1e-4 --ls_prec 1e-10 --prec 1e-4 -M msk"
bare(params)
def bare(params):
@@ -18,6 +18,12 @@ def cmc(params):
retcode = check_output(["./ranksvm"]+params.split(" ")).decode('ascii').split('\n');
return float(retcode[-2])
+def cmcc(params):
+ devnull = open(os.devnull, 'w')
+ retcode = check_output(["./ranksvm"]+params.split(" ")).decode('ascii').split('\n');
+ return [float(x) for x in retcode[1:101]]
+
+
def take(fname):
f=open(fname,'r')
res=f.read().split('\n')
diff --git a/train.py b/train.py
index 82480e6..ca117b8 100644
--- a/train.py
+++ b/train.py
@@ -2,12 +2,14 @@ import ensemble
import cross
import time
-clist=[0.0001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 10, 100, 1000]
+clist=[0.1, 0.3, 1, 10, 30, 100, 300, 1e3, 3e3 , 1e4]
inr="a.rid"
tot=316
-folds=2
+folds=4
-for i in range(10):
+clist.reverse();
+
+for i in range(60):
print("iter %d:" %i)
start = time.time();
cross.onefold(inr,"%d.m"%i,316,clist,folds,"%d.m"%(i+1))