From d9c104c731142c0347c6c476c3c8599410710ef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9gory=20Soutad=C3=A9?= Date: Wed, 24 Feb 2016 19:17:40 +0100 Subject: [PATCH] Bad subgroup linking in build_c_array and recursion was stopped too early for ipv6 addresses --- data/build_c_array.py | 44 +++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/data/build_c_array.py b/data/build_c_array.py index 52ec32c..9a03835 100755 --- a/data/build_c_array.py +++ b/data/build_c_array.py @@ -9,7 +9,7 @@ IP_INDEX=3 IP_SIZE_INDEX=4 class IP_ELEMENT(object): - def __init__(self, start, end=None, size=0, country_code=None, level=0): + def __init__(self, start, end=None, size=0, country_code=None, level=0, is_group=False): self._start = start self._end = end self._size = size @@ -19,6 +19,7 @@ class IP_ELEMENT(object): self._childs = None self._average = 0 self._level = level + self._is_group = is_group if not self._end: self._compute_last_ip() @@ -50,7 +51,10 @@ class IP_ELEMENT(object): return self._separator.join([self._format % x for x in ip_val]) def name(self): - return 'ip__%s__%s' %(self._start.replace(self._separator, '_'), self._end.replace(self._separator, '_')) + name = 'ip__' + if self._is_group: + name += 'g%d_' % (self._level) + return name + '%s__%s' %(self._start.replace(self._separator, '_'), self._end.replace(self._separator, '_')) def _compute_last_ip(self): raise NotImplementedError() @@ -86,11 +90,11 @@ class IP_ELEMENT(object): class IP_ELEMENT4(IP_ELEMENT): - def __init__(self, start, end=None, size=0, country_code=None, level=0): + def __init__(self, start, end=None, size=0, country_code=None, level=0, is_group=False): self._separator = '.' self._base = 10 self._format = '%d' - super(IP_ELEMENT4, self).__init__(start, end, size, country_code, level) + super(IP_ELEMENT4, self).__init__(start, end, size, country_code, level, is_group) def get_ip_len(self): return 4 @@ -104,15 +108,14 @@ class IP_ELEMENT4(IP_ELEMENT): size = int(size/256) i += 1 self._end = self.ip_to_str(end_ip) - # print '%s + %d -> %s' % (self._start, self._size, self._end) class IP_ELEMENT6(IP_ELEMENT): - def __init__(self, start, end=None, size=0, country_code=None, level=0): + def __init__(self, start, end=None, size=0, country_code=None, level=0, is_group=False): self._separator = ':' self._base = 16 self._format = '%02x' - super(IP_ELEMENT6, self).__init__(start, end, size, country_code, level) + super(IP_ELEMENT6, self).__init__(start, end, size, country_code, level, is_group) def get_ip_len(self): return 16 @@ -179,7 +182,7 @@ print '/* This file was automatically generated, do not edit it ! */' print '#include \n\n' def ip_sort(a, b): - for i in range(0, len(a._splitted_start)): + for i in range(0, a.get_ip_len()): if a._splitted_start[i] != b._splitted_start[i]: return a._splitted_start[i] - b._splitted_start[i] return 0 @@ -201,7 +204,7 @@ def get_interval(root, intervals, level): def print_interval(interval): p = '[' for i in interval: - p += '%s, ' % (i.name()) + p += '%s,\n' % (i.name()) p += ']' return p @@ -213,6 +216,7 @@ def compute_average(root): total += 1 count += (child._splitted_end[child._level] - child._splitted_start[child._level] + 1) child = child._next + if not total: return average = int(count/total) # Find highest power of 2 < average for i in range(0, 9): @@ -220,12 +224,12 @@ def compute_average(root): root.set_average(i-1) break -def manage_root(root, intervals, level): +def manage_root(root, intervals, level, max_depth): cur_start = 0 prev = None first = None cur_len = 0 - if level >= 3: return (0, None) + if level >= max_depth: return None # print 'manage_root(%d, %s, %d)' %\ # (root, print_interval(intervals), level) while True: @@ -238,10 +242,10 @@ def manage_root(root, intervals, level): cur_ip.set_level(level+1) for ip in sub_interval: ip.set_level(level+1) - new_group = cur_ip.__class__(cur_ip.make_group(), level=level) + new_group = cur_ip.__class__(cur_ip.make_group(), level=level, is_group=True) sub_interval.insert(0, cur_ip) - child = manage_root(cur_ip._splitted_start[level+1], sub_interval, level+1) - new_group.set_childs(cur_ip) + child = manage_root(cur_ip._splitted_start[level+1], sub_interval, level+1, max_depth) + new_group.set_childs(child) compute_average(new_group) cur_ip = new_group cur_start += len(sub_interval) @@ -268,8 +272,8 @@ def print_ip(ip): cur_ip.printme() cur_ip = cur_ip._next -def build_array(ip_list, array_name): - ip_list.sort(ip_sort) +def build_array(ip_list, array_name, max_depth): + ip_list.sort(ip_sort) start_idx = 0 end_idx = start_idx+1 cur_interval = [ip_list[start_idx]] @@ -280,7 +284,7 @@ def build_array(ip_list, array_name): if end_idx >= len(ip_list): break if ip_list[end_idx]._splitted_start[0] != root: start_idx = end_idx - res = manage_root(root, cur_interval, 1) + res = manage_root(root, cur_interval, 1, max_depth) print_ip(res) root_ips[res._splitted_start[0]] = res cur_interval = [ip_list[end_idx]] @@ -288,7 +292,7 @@ def build_array(ip_list, array_name): else: cur_interval.append(ip_list[end_idx]) end_idx += 1 - res = manage_root(root, cur_interval, 1) + res = manage_root(root, cur_interval, 1, max_depth) print_ip(res) print '\nstatic const ip_level* %s[256] = {' % (array_name) @@ -299,8 +303,8 @@ def build_array(ip_list, array_name): print '\tNULL, // %d' % (i) print '};\n' -build_array(array_vals_ipv4.values(), 's_root_ipv4') -build_array(array_vals_ipv6.values(), 's_root_ipv6') +build_array(array_vals_ipv4.values(), 's_root_ipv4', 3) +build_array(array_vals_ipv6.values(), 's_root_ipv6', 15) print 'static const uint8_t country_codes[][3] = {' for cc in countries: