pytstat
changeset 135:5507731b5401
added tool to propagate endpoint label to flow level
| author | Alessandro Finamore <alessandro.finamore@ocracy.org> |
|---|---|
| date | Tue Oct 13 15:52:13 2009 +0200 (9 months ago) |
| parents | 2ff556cf365e |
| children | ac977bb4ca4d |
| files | geometry/roc.pyc geometry/svm.pyc gnuplot/cumulfreq.pyc lib/dataset.pyc lib/predict.py lib/predict.pyc lib/protocols.py lib/protocols.pyc lib/regex.py lib/regex.pyc makesvmtab.py other-tools/flow_label.py other-tools/gplt-payload.py other-tools/libpayload.py other-tools/readpayload.py parse/resultfile.pyc script-bash/cp-models.bash script-bash/init_envvar.bash splitter.py |
line diff
1.1 Binary file geometry/roc.pyc has changed
2.1 Binary file geometry/svm.pyc has changed
3.1 Binary file gnuplot/cumulfreq.pyc has changed
4.1 Binary file lib/dataset.pyc has changed
5.1 --- a/lib/predict.py Wed Jul 09 17:18:38 2008 +0200 5.2 +++ b/lib/predict.py Tue Oct 13 15:52:13 2009 +0200 5.3 @@ -10,6 +10,10 @@ 5.4 import re 5.5 from time import time 5.6 5.7 +##XXX 5.8 +ENTROPY = False 5.9 + 5.10 + 5.11 __obj = NetidStats() 5.12 5.13 def __init(data_dict, model, thr_model, threshold, 5.14 @@ -67,9 +71,11 @@ 5.15 5.16 for curr_proto in __obj.protos: 5.17 centroid = __obj.model[__obj.dir_class][curr_proto] 5.18 - #sample_scaled = apply_scale(sample, _step_sample) 5.19 - sample_scaled = scale_vector(sample, __obj.data_dict.tot_samples) 5.20 - centroid = scale_vector(centroid, __obj.data_dict.tot_samples) 5.21 + if not ENTROPY: 5.22 + sample_scaled = scale_vector(sample, __obj.data_dict.tot_samples) 5.23 + centroid = scale_vector(centroid, __obj.data_dict.tot_samples) 5.24 + else: 5.25 + sample_scaled = sample 5.26 5.27 curr_dist = euclidean_dist(centroid, sample_scaled) 5.28 report_string += " %7.3f" % (curr_dist) 5.29 @@ -112,7 +118,9 @@ 5.30 5.31 ## covert sample vector in svm format dictionar 5.32 __obj.stats.tot_test += 1 5.33 - sample = scale_vector(sample, __obj.data_dict.tot_samples) 5.34 + if not ENTROPY: 5.35 + sample = scale_vector(sample, __obj.data_dict.tot_samples) 5.36 + 5.37 svm_dict = svm.get_dict_format(sample) 5.38 5.39
6.1 Binary file lib/predict.pyc has changed
7.1 --- a/lib/protocols.py Wed Jul 09 17:18:38 2008 +0200 7.2 +++ b/lib/protocols.py Tue Oct 13 15:52:13 2009 +0200 7.3 @@ -8,17 +8,23 @@ 7.4 import sys 7.5 from gnuplot.colors import COLORS_DICT 7.6 7.7 -PROTO_NAMES = ["rtp", "edk", "port53"] 7.8 +PROTO_NAMES = [] 7.9 +PROTO_NAMES.append("rtp") 7.10 +PROTO_NAMES.append("edk") 7.11 +PROTO_NAMES.append("port53") 7.12 #PROTO_NAMES.append("skype") 7.13 + 7.14 +#PROTO_NAMES.append("pplive") 7.15 +#PROTO_NAMES.append("joost") 7.16 +#PROTO_NAMES.append("tvants") 7.17 +#PROTO_NAMES.append("sopcast") 7.18 +# 7.19 +#PROTO_NAMES.append("other") 7.20 PROTO_NAMES.append("backg") 7.21 #PROTO_NAMES.append("univer") 7.22 #PROTO_NAMES.append("genop") 7.23 -#PROTO_NAMES = ["rtp", "edk", "port53", "genunf"] 7.24 -#PROTO_NAMES = ["pplive", "joost", "tvants", "sopcast"] 7.25 -#PROTO_NAMES.append("univerFW") 7.26 PROTO_NOTSET = "-" 7.27 -#PROTO_NAMES = ["rtp", "edk", "port53", "unknow"] #, "port6348"] 7.28 -#PROTO_NAMES = ["rtp", "edk", "port53", "rtp-p", "edk-p", "port53-p", "unknow"] #, "port6348"] 7.29 + 7.30 PROTO_NAMES.sort() 7.31 7.32 #RTP = PROTO_NAMES.index("rtp")
8.1 Binary file lib/protocols.pyc has changed
9.1 --- a/lib/regex.py Wed Jul 09 17:18:38 2008 +0200 9.2 +++ b/lib/regex.py Tue Oct 13 15:52:13 2009 +0200 9.3 @@ -8,7 +8,7 @@ 9.4 spaces = re.compile(r"( |\t)+") 9.5 comments = re.compile(r"#.*") 9.6 comas = re.compile(r",") 9.7 -protocols = re.compile(r"|".join(protocols.PROTO_NAMES)) 9.8 +protocols = re.compile(r"|".join([ r'\b%s\b' % p for p in protocols.PROTO_NAMES])) 9.9 9.10 expression = "|".join(["(%s)" % el for el in colors.COLORS_AREA.keys()]) 9.11 chunk_area = re.compile(r"%s" % expression)
10.1 Binary file lib/regex.pyc has changed
11.1 --- a/makesvmtab.py Wed Jul 09 17:18:38 2008 +0200 11.2 +++ b/makesvmtab.py Tue Oct 13 15:52:13 2009 +0200 11.3 @@ -35,6 +35,9 @@ 11.4 for row in range(len(obj.data[0])): 11.5 l = ["%s:%s" % (ind + 1, float(obj.data[ind][row][1]) / data_dict.tot_samples) \ 11.6 for ind in range(data_dict.tot_chunks)] 11.7 + ##XXX 11.8 + #l = [ "%s:%s" % (ind + 1, obj.data[ind][row][1]) for ind in range(data_dict.tot_chunks) ] 11.9 + ##XXX 11.10 l.insert(0, label) 11.11 l.append("\n") 11.12 report.append(" ".join(l))
12.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000 12.2 +++ b/other-tools/flow_label.py Tue Oct 13 15:52:13 2009 +0200 12.3 @@ -0,0 +1,136 @@ 12.4 +#!/usr/bin/env python 12.5 +# -*- coding: utf-8 -*- 12.6 + 12.7 +import sys 12.8 +from os import path 12.9 + 12.10 +svmclasses = ["joost", "tvants", "sopcast", "pplive", "fw06"] 12.11 + 12.12 +_conflict = 0 12.13 +_line_num = 0 12.14 +def find_label(ipA, portA, ipB, portB, d_src, d_dst): 12.15 + global _conflict, _line_num 12.16 + epntA = "%s:%s" % (ipA, portA) 12.17 + epntB = "%s:%s" % (ipB, portB) 12.18 + 12.19 + labelA = None 12.20 + labelB = None 12.21 + _line_num += 1 12.22 + if d_src.has_key(epntA): 12.23 + labelA = d_src[epntA] 12.24 + if d_dst.has_key(epntB): 12.25 + labelB = d_dst[epntB] 12.26 + 12.27 + if labelA and labelB and labelA != labelB: 12.28 + _conflict += 1 12.29 + sys.stderr.write("[warn%d - line%d] conflict in label, src:%s dst:%s\n" % \ 12.30 + (_conflict, _line_num, labelA, labelB)) 12.31 + 12.32 + labelA = 'conflict' 12.33 + labelB = 'conflict' 12.34 + 12.35 + label = labelA 12.36 + if labelB: 12.37 + label = labelB 12.38 + return label 12.39 + 12.40 +def load_predid(fname, d): 12.41 + f = open(fname, "r") 12.42 + while True: 12.43 + line = f.readline() 12.44 + if line == '': 12.45 + break 12.46 + 12.47 + ## extract only labeled lines 12.48 + ## - 1st column is the endpoint label 12.49 + ## - 2nd column is the endpoint id 12.50 + valid = False 12.51 + for c in svmclasses: 12.52 + if line.startswith(c): 12.53 + valid = True 12.54 + break 12.55 + if not valid: 12.56 + continue 12.57 + 12.58 + words = line.split() 12.59 + label = words[0] 12.60 + endpoint = words[1] 12.61 + if d.has_key(endpoint) and d[endpoint] != label: 12.62 + print "'%s': conflict for endpoint=%s, %s or %s?" % \ 12.63 + (fname, endpoint, d[endpoint], label) 12.64 + sys.exit(1) 12.65 + d[endpoint] = label 12.66 + f.close() 12.67 + return d 12.68 + 12.69 +if __name__ == "__main__": 12.70 + ## check for command line arguments 12.71 + if len(sys.argv) < 3 or \ 12.72 + sys.argv[3].lower() != 'tab' and sys.argv[3].lower() != 'test': 12.73 + print sys.argv[0], '<tstat_log_flows> <kiss_proto_dir> tab|test' 12.74 + sys.exit(1) 12.75 + 12.76 + flow_log = sys.argv[1] 12.77 + pred_dir = sys.argv[2] 12.78 + ext = sys.argv[3].lower() 12.79 + 12.80 + # load endpoint classification 12.81 + in_src, in_dst = {}, {} 12.82 + out_src, out_dst = {}, {} 12.83 + fname = path.join(pred_dir, 'in', 'epnt_src.' + ext + '.svmpredid') 12.84 + in_src = load_predid(fname, in_src) 12.85 + fname = path.join(pred_dir, 'in', 'epnt_dst.' + ext + '.svmpredid') 12.86 + in_dst = load_predid(fname, in_dst) 12.87 + fname = path.join(pred_dir, 'out', 'epnt_src.' + ext + '.svmpredid') 12.88 + out_src = load_predid(fname, out_src) 12.89 + fname = path.join(pred_dir, 'out', 'epnt_dst.' + ext + '.svmpredid') 12.90 + out_dst = load_predid(fname, out_dst) 12.91 + 12.92 + # propagate endpoint classification to flow level 12.93 + f = open(flow_log) 12.94 + while True: 12.95 + line = f.readline() 12.96 + if line == '': 12.97 + break 12.98 + 12.99 + words = line.split() 12.100 + timestamp = words[2] 12.101 + ipA, portA = words[0:2] 12.102 + ipB, portB = words[8:10] 12.103 + internalA = bool(int(words[6])) 12.104 + internalB = bool(int(words[14])) 12.105 + pktA = int(words[5]) 12.106 + pktB = int(words[13]) 12.107 + bytesA = int(words[4]) 12.108 + bytesB = int(words[12]) 12.109 + 12.110 + ## OUT 12.111 + if internalA and not internalB: 12.112 + label = find_label(ipA, portA, ipB, portB, out_src, out_dst) 12.113 + 12.114 + ## IN 12.115 + elif not internalA and internalB: 12.116 + label = find_label(ipA, portA, ipB, portB, in_src, in_dst) 12.117 + 12.118 + ## LOCALE 12.119 + elif internalA and internalB: 12.120 + labelA = find_label(ipA, portA, ipB, portB, in_src, in_dst) 12.121 + labelB = find_label(ipA, portA, ipB, portB, out_src, out_dst) 12.122 + if labelA and labelB and labelA != labelB: 12.123 + print 'conflict %s-VS-%s line: %s'\ 12.124 + (labelA, labelB, line) 12.125 + sys.exit(1) 12.126 + label = labelA 12.127 + if not labelA and labelB: 12.128 + label = labelB 12.129 + 12.130 + ## ERRROR 12.131 + else: 12.132 + print "external flow!!! %s" % line 12.133 + sys.exit(1) 12.134 + 12.135 + if not label: 12.136 + label = 'notsupported' 12.137 + print "%s,%s,%s,%s,%s,%d,%d,udp,%s" % \ 12.138 + (timestamp, ipA, portA, ipB, portB, bytesA + bytesB, pktA + pktB, label) 12.139 + f.close()
13.1 --- a/other-tools/gplt-payload.py Wed Jul 09 17:18:38 2008 +0200 13.2 +++ b/other-tools/gplt-payload.py Tue Oct 13 15:52:13 2009 +0200 13.3 @@ -37,17 +37,24 @@ 13.4 if not opt.in_fname: 13.5 print "error: missing input file" 13.6 sys.exit(1) 13.7 - try: 13.8 - f = open(opt.in_fname, "r") 13.9 - except: 13.10 - print "error: '%s' No such file" % opt.in_fname 13.11 - sys.exit(1) 13.12 + 13.13 + if opt.in_fname == '-': 13.14 + opt.in_file = sys.stdin 13.15 + else: 13.16 + try: 13.17 + opt.in_file = open(opt.in_fname, "r") 13.18 + except: 13.19 + print "error: '%s' No such file" % opt.in_fname 13.20 + sys.exit(1) 13.21 13.22 13.23 if opt.out_dir == None: 13.24 - opt.out_dir = path.split(opt.in_fname)[0] 13.25 - if opt.out_dir == "": 13.26 - opt.out_dir = "." 13.27 + if opt.in_fname == '-': 13.28 + opt.out_dir = 'stdin' 13.29 + else: 13.30 + opt.out_dir = path.split(opt.in_fname)[0] 13.31 + if opt.out_dir == "": 13.32 + opt.out_dir = "." 13.33 ## check output directory 13.34 if not path.isdir(opt.out_dir): 13.35 print "error: '%s' No such directory" % opt.out_dir 13.36 @@ -184,28 +191,22 @@ 13.37 if __name__ == "__main__": 13.38 opt = cmdline_parse() 13.39 13.40 - try: 13.41 - file = open(opt.in_fname, "r") 13.42 - except: 13.43 - print "error opening '%s'" % opt.in_fname 13.44 - sys.exit(1) 13.45 - 13.46 regex_spaces = re.compile(" +") 13.47 bytes = -1 13.48 lines = 0 13.49 while True: 13.50 - line = file.readline() 13.51 + line = opt.in_file.readline() 13.52 if line == "": 13.53 break 13.54 13.55 line = line.strip() 13.56 - if line[0] == "#" or line[0] == "": 13.57 + if line == '' or line[0] == "#": 13.58 continue 13.59 if bytes == -1: 13.60 line = regex_spaces.sub(" ", line) 13.61 bytes = len(line.split()) - 2 13.62 lines += 1 13.63 - file.close() 13.64 + opt.in_file.close() 13.65 if lines == 0: 13.66 print "'%s' is void" % opt.in_fname 13.67 sys.exit(0)
14.1 --- a/other-tools/libpayload.py Wed Jul 09 17:18:38 2008 +0200 14.2 +++ b/other-tools/libpayload.py Tue Oct 13 15:52:13 2009 +0200 14.3 @@ -28,6 +28,10 @@ 14.4 if next_type == 0x11: 14.5 hdr, bin_data = self.decode_udp(bin_data) 14.6 14.7 + ## tcp 14.8 + if next_type == 0x06: 14.9 + hdr, bin_data = self.decode_tcp(bin_data) 14.10 + 14.11 self.data = [ self.b2i(byte) for byte in bin_data ] 14.12 if len(self.data) < 10: 14.13 data = "".join([ "%02x" % byte for byte in self.data ]) 14.14 @@ -38,8 +42,11 @@ 14.15 "DATA: %s" % data 14.16 14.17 def get_pay_len(self): 14.18 - ## valid only with UDP!!! 14.19 - return self.ip_len - self.ip_hdrlen - 8 14.20 + if hasattr(self, 'tcp_hdr_len'): 14.21 + return self.ip_len - self.ip_hdrlen - self.tcp_hdr_len * 4 14.22 + if hasattr(self, 'udp_len'): 14.23 + return self.ip_len - self.ip_hdrlen - 8 14.24 + return -1 14.25 14.26 def b2i(self, bin_val): 14.27 return int(hexlify(bin_val), 16) 14.28 @@ -52,7 +59,7 @@ 14.29 self.eth_dst_mac = ":".join([ hexlify(data[i]) for i in range(6,12) ]) 14.30 self.eth_type = self.b2i(data[12:14]) 14.31 self.out_string += \ 14.32 - "ETH : SRC: %s Dst: %s Type=%s\n" %\ 14.33 + "ETH : SRC:%s Dst:%s TYPE:%s\n" %\ 14.34 (self.eth_src_mac, self.eth_dst_mac, hex(self.eth_type)) 14.35 return data[0:14], data[14:] 14.36 14.37 @@ -83,20 +90,30 @@ 14.38 self.ip_src = ".".join([ str(self.b2i(data[i])) for i in range(12, 16) ]) 14.39 self.ip_dst = ".".join([ str(self.b2i(data[i])) for i in range(16, 20) ]) 14.40 self.out_string += \ 14.41 - "IPv%d: SRC: %s DST: %s Proto=%s\n" %\ 14.42 + "IPv%d: SRC:%s DST:%s Proto:%s\n" %\ 14.43 (self.ip_version, self.ip_src, self.ip_dst, hex(self.ip_proto)) 14.44 return data[0:self.ip_hdrlen], data[self.ip_hdrlen:] 14.45 14.46 def decode_udp(self, data): 14.47 - self.udp_src = self.b2i(data[0:2]) 14.48 - self.udp_dst = self.b2i(data[2:4]) 14.49 + self.port_src = self.b2i(data[0:2]) 14.50 + self.port_dst = self.b2i(data[2:4]) 14.51 self.udp_len = self.b2i(data[4:6]) 14.52 self.udp_checksum = self.b2i(data[6:8]) 14.53 self.out_string += \ 14.54 "UDP : SRC:%d DST:%d LEN:%d\n" %\ 14.55 - (self.udp_src, self.udp_dst, self.udp_len) 14.56 + (self.port_src, self.port_dst, self.udp_len) 14.57 return data[0:8], data[8:] 14.58 14.59 + def decode_tcp(self, data): 14.60 + self.port_src = self.b2i(data[0:2]) 14.61 + self.port_dst = self.b2i(data[2:4]) 14.62 + ##TO FIX!!! 14.63 + self.tcp_hdr_len= self.b2i(data[12]) >> 4 14.64 + self.out_string +=\ 14.65 + "TCP : SRC:%d DST:%d" %\ 14.66 + (self.port_src, self.port_dst) 14.67 + return data[0:self.tcp_hdr_len * 4], data[self.tcp_hdr_len * 4:] 14.68 + 14.69 def __str__(self): 14.70 return self.out_string 14.71
15.1 --- a/other-tools/readpayload.py Wed Jul 09 17:18:38 2008 +0200 15.2 +++ b/other-tools/readpayload.py Tue Oct 13 15:52:13 2009 +0200 15.3 @@ -6,21 +6,22 @@ 15.4 import sys,re 15.5 from os import path 15.6 15.7 -usage = \ 15.8 -"usage: %prog -r <dump_file> [-w file] [--bytes=num] [--offset=num] [--pckts=num]\n"\ 15.9 -" [--int][--diff] [tcpdump_filter_options]\n"\ 15.10 -"\n"\ 15.11 -" Use tshark packet inspector to extract packet payload from a dump file and generate\n"\ 15.12 -" a multi column report containing:\n"\ 15.13 -" - timestamp in unix format (\"seconds.msec\" since Jan 1, 1970 00:00:00);\n"\ 15.14 -" - payload length;\n"\ 15.15 -" - payload bytes (one for each column)\n"\ 15.16 -"\n"\ 15.17 -" dump_file can be filtered using --bytes, --offset and --pckts options to select\n"\ 15.18 -" a specific piece of payload but can also be used a generic wireshark filter\n"\ 15.19 -"\n"\ 15.20 -" Extracted bytes can be viewed as hexadecimal or integer values (--int) and also as\n"\ 15.21 -" delta increments between consecutives packets (--diff).\n" 15.22 +usage = ''' 15.23 +usage: %prog -r <dump_file> [-w file] [--bytes=num] [--offset=num] [--pckts=num] 15.24 + [--int][--diff][--ip] [tcpdump_filter_options] 15.25 + 15.26 + Use tshark packet inspector to extract packet payload from a dump file and generate 15.27 + a multi column report containing: 15.28 + - timestamp in unix format (\"seconds.msec\" since Jan 1, 1970 00:00:00); 15.29 + - payload length; 15.30 + - payload bytes (one for each column) 15.31 + 15.32 + dump_file can be filtered using --bytes, --offset and --pckts options to select 15.33 + a specific piece of payload but can also be used a generic wireshark filter 15.34 + 15.35 + Extracted bytes can be viewed as hexadecimal or integer values (--int) and also as 15.36 + delta increments between consecutives packets (--diff). 15.37 +''' 15.38 15.39 def cmdline_parse(): 15.40 parser = OptionParser() 15.41 @@ -49,6 +50,10 @@ 15.42 action='store_true', default=False, 15.43 help = 'print time elapsed from precedent packet') 15.44 15.45 + parser.add_option('', '--ip', dest='print_ip_info', 15.46 + action = 'store_true', default = False, 15.47 + help = 'print ip addresses and port numbers') 15.48 + 15.49 opt, other_opt = parser.parse_args() 15.50 15.51 ## check input file 15.52 @@ -123,7 +128,13 @@ 15.53 else: 15.54 timestamp = pck.timestamp - first_timestamp 15.55 first_timestamp = pck.timestamp 15.56 - line = "%14.6f %5d " % (timestamp, pck.get_pay_len()) 15.57 + 15.58 + line = '' 15.59 + if opt.print_ip_info: 15.60 + line += '%15s %5s %15s %5s ' % \ 15.61 + (pck.ip_src, pck.port_src, pck.ip_dst, pck.port_dst) 15.62 + 15.63 + line += "%14.6f %5d " % (timestamp, pck.get_pay_len()) 15.64 15.65 if opt.print_int: 15.66 line += " ".join([ "%3d" % val for val in data ])
16.1 Binary file parse/resultfile.pyc has changed
17.1 --- a/script-bash/cp-models.bash Wed Jul 09 17:18:38 2008 +0200 17.2 +++ b/script-bash/cp-models.bash Tue Oct 13 15:52:13 2009 +0200 17.3 @@ -9,7 +9,7 @@ 17.4 mod_name=`echo $mod_dir | sed -r 's/.*\/(.*)$/\1/'` 17.5 new_name=`echo $new_dir | sed -r 's/.*\/(.*)$/\1/'` 17.6 17.7 -cp -r $mod_dir $new_dir 17.8 +cp -r $mod_dir/* $new_dir 17.9 if [ $? -ne 0 ]; then 17.10 exit 1 17.11 fi
18.1 --- a/script-bash/init_envvar.bash Wed Jul 09 17:18:38 2008 +0200 18.2 +++ b/script-bash/init_envvar.bash Tue Oct 13 15:52:13 2009 +0200 18.3 @@ -2,27 +2,31 @@ 18.4 18.5 TOTCHUNKS=${TOTCHUNKS:-24} 18.6 TOTSAMPLES=${TOTSAMPLES:-80} 18.7 -TRAINSIZE=${TRAINSIZE:-400} 18.8 +TRAINSIZE=${TRAINSIZE:-300} 18.9 18.10 -DEFAULT_PROTOS="rtp edk port53" 18.11 -#DEFAULT_PROTOS="joost pplive sopcast tvants" 18.12 -#i= 18.13 -DEFAULT_PROTOS="$DEFAULT_PROTOS backg${i}0 backg${i}1 backg${i}2 backg${i}3 backg${i}4 backg${i}5 backg${i}6 backg${i}7 backg${i}8 backg${i}9" 18.14 -#DEFAULT_PROTOS="univer${i}0 univer${i}1 univer${i}2 univer${i}3 univer${i}4 univer${i}5 univer${i}6 univer${i}7 univer${i}8 univer${i}9" 18.15 +DEFAULT_PROTOS="rtp edk port53 backg" 18.16 +#DEFAULT_PROTOS="joost pplive sopcast tvants other skype" 18.17 +DEFAULT_PROTOS="$DEFAULT_PROTOS" 18.18 +#for i in {1..9}; do 18.19 +# DEFAULT_PROTOS="$DEFAULT_PROTOS backg$i" 18.20 +#done 18.21 18.22 PROTOS=${PROTOS:-$DEFAULT_PROTOS} 18.23 DIRS=${DIRS:-"in out"} 18.24 18.25 SVMMODELS=${SVMMODELS:-"svm2"} 18.26 #SVMMODELS="" 18.27 -#MODELS=${MODELS:-"euc $SVMMODELS"} 18.28 -MODELS=${MODELS:-"$SVMMODELS"} 18.29 +#MODELS="euc" 18.30 +#MODELS=${MODELS:-"euc"} 18.31 +MODELS=${MODELS:-"euc $SVMMODELS"} 18.32 18.33 +#THRPROTO=${THRPROTO:-"other"} 18.34 THRPROTO=${THRPROTO:-"backg"} 18.35 +#THRPROTO=${THRPROTO:-"aggr"} 18.36 #THRPROTO=${THRPROTO:-"univer"} 18.37 #THRPROTO=${THRPROTO:-"univerFW"} 18.38 THRIMGOPT=${THRIMG:-"-PG"} 18.39 -THRENABLED=${THRENABLED:-0} 18.40 +THRENABLED=${THRENABLED:-1} 18.41 18.42 PYTSTAT=${PYTSTAT:-"/home/fina/Tesi/src/pytstat"} 18.43
19.1 --- a/splitter.py Wed Jul 09 17:18:38 2008 +0200 19.2 +++ b/splitter.py Tue Oct 13 15:52:13 2009 +0200 19.3 @@ -23,6 +23,15 @@ 19.4 help="train set size [default=300]") 19.5 parser.add_option("-m", "", type = "int", default = -1, dest="test_size", 19.6 help="test set size (by default unbound)") 19.7 + parser.add_option('', '--disable-stratify', action = 'store_true', 19.8 + dest = 'disable_stratify', default = False, 19.9 + help = 'When input file is loaded, samples are reorganized by endpoint.'\ 19.10 + 'By default sampling is stratified selecting uniformly an endpoint'\ 19.11 + 'and then a sample from it. '\ 19.12 + 'Enabling this option allows to disable the stratification '\ 19.13 + 'so each line of the input file has the same probability '\ 19.14 + 'to be selected') 19.15 + 19.16 19.17 cmdline_opt, other = parser.parse() 19.18 19.19 @@ -82,10 +91,7 @@ 19.20 19.21 return train_set, test_set 19.22 19.23 - 19.24 -if __name__ == "__main__": 19.25 - cmdline_opt = parse_cmdline_opt() 19.26 - 19.27 +def stratify_sampling(cmdline_opt): 19.28 ## load samples from file 19.29 test_set = chilog.ChiDict( 19.30 is_epnt_format = cmdline_opt.is_epnt_format, 19.31 @@ -116,8 +122,66 @@ 19.32 s = names.get_out_fname(cmdline_opt.in_fname, "test", cmdline_opt.out_dir) 19.33 test_set.save(s, hdr) 19.34 19.35 +def simple_sampling(cmdline_opt): 19.36 + f = open(cmdline_opt.in_fname, 'r') 19.37 19.38 + ## load lines from input file 19.39 + lines = [] 19.40 + for line in f.readlines(): 19.41 + if line[0] == '#': 19.42 + continue 19.43 + if line == '': 19.44 + break 19.45 + lines.append(line) 19.46 + 19.47 + ## extract a random sequence of lines 19.48 + nums = range(len(lines)) 19.49 + rand_ind = [] 19.50 + while nums: 19.51 + n = random.randint(0, len(nums) - 1) 19.52 + rand_ind.append(nums.pop(n)) 19.53 19.54 + ## split random sequence in train and test portion 19.55 + train_ind = rand_ind[:cmdline_opt.train_size] 19.56 + if cmdline_opt.test_size < 0: 19.57 + test_ind = rand_ind[cmdline_opt.train_size:] 19.58 + else: 19.59 + test_ind = rand_ind[cmdline_opt.train_size : \ 19.60 + cmdline_opt.train_size + cmdline_opt.test_size] 19.61 19.62 + ## write train file 19.63 + hdr = "## file generated by '%s'\n"\ 19.64 + "## original file '%s'\n" %\ 19.65 + (check.program_name(), cmdline_opt.in_fname) 19.66 + s = names.get_out_fname(cmdline_opt.in_fname, "train", cmdline_opt.out_dir) 19.67 + f = open(s, 'w') 19.68 + f.write(hdr) 19.69 + for i in train_ind: 19.70 + f.write(lines[i]) 19.71 + f.close() 19.72 19.73 + ## write test file 19.74 + s = names.get_out_fname(cmdline_opt.in_fname, "test", cmdline_opt.out_dir) 19.75 + f = open(s, 'w') 19.76 + f.write(hdr) 19.77 + for i in test_ind: 19.78 + f.write(lines[i]) 19.79 + f.close() 19.80 19.81 + 19.82 + 19.83 + 19.84 +if __name__ == "__main__": 19.85 + cmdline_opt = parse_cmdline_opt() 19.86 + 19.87 + if cmdline_opt.disable_stratify: 19.88 + simple_sampling(cmdline_opt) 19.89 + else: 19.90 + stratify_sampling(cmdline_opt) 19.91 + 19.92 + 19.93 + 19.94 + 19.95 + 19.96 + 19.97 +
