summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-06-02 11:42:11 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-06-02 11:42:11 +0800
commit63627a714d73ee4ec3e0b23755d96baf612948a9 (patch)
treede8948adac0d1e2cfe958fa012e81f9f352d6c22
parent8cee6cb53bf85e9a96e82d1116b2d7ac58d3c8df (diff)
downloadcross-63627a714d73ee4ec3e0b23755d96baf612948a9.tar.gz
cross-63627a714d73ee4ec3e0b23755d96baf612948a9.tar.bz2
cross-63627a714d73ee4ec3e0b23755d96baf612948a9.zip
cross-finished
-rw-r--r--main.py34
-rw-r--r--misc.py5
2 files changed, 16 insertions, 23 deletions
diff --git a/main.py b/main.py
index 6b04638..5fa9504 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
from misc import *
-clist=[0.001]
+clist=[0.0001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 10, 100, 1000]
inr="a.rid"
@@ -40,37 +40,29 @@ for ite in range(10):
put(rid,entries)
optc=1
- bcmc=-1
+ bcmc=folds
for c in clist:
print(c)
oum="%d-%g.m" % (o,c)
- acmc = -1;
+ acmc = 0;
for a in range(folds):
rid = "b%d.rid"%a
- params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,rid,oum,c)
+ params = "-T -d -m %s -i %s -o %s -c %g --cg_iter 100" % (inm,rid,oum,c)
train(params)
+ rid="a%d.rid"%a
params = "-V -C -m %s -i %s -s" %(oum,rid)
- ncmc=cmc(params)
- if acmc==-1:
- acmc=ncmc
- else:
- for b in range(100):
- acmc[b]+=ncmc[b]
-
- if bcmc==-1:
+ acmc+=cmc(params)
+
+ print(acmc/folds)
+
+ if bcmc>acmc:
optc=c
- bcmc=ncmc
- else:
- for b in range(100):
- if bcmc[b]<ncmc[b]:
- optc=c
- bcmc=ncmc
- break
+ bcmc=acmc
-
+ print("train with: %g" % optc)
oum="%d.m"% o
- params = "-T -d -m %s -i %s -o %s -c %g --iter 1 > /dev/null" % (inm,inr,oum,optc)
+ params = "-T -d -m %s -i %s -o %s -c %g" % (inm,inr,oum,optc)
train(params)
diff --git a/misc.py b/misc.py
index 7bf1370..1a946ed 100644
--- a/misc.py
+++ b/misc.py
@@ -7,12 +7,13 @@ def split(params):
def train(params):
devnull = open(os.devnull, 'w')
+ params+=" --iter 1 --cg_prec 1e-4 --ls_prec 1e-10"
call(["./ranksvm"]+params.split(" "),stdout=devnull)
def cmc(params):
devnull = open(os.devnull, 'w')
- retcode = check_output(["./ranksvm"]+params.split(" "));
- return [float(i) for i in retcode.split('\n')[1:]]
+ retcode = check_output(["./ranksvm"]+params.split(" ")).decode('ascii').split('\n');
+ return float(retcode[-2])
def take(fname):
f=open(fname,'r')