#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Generate C lookup tables mapping IP address ranges to country codes.

The input is a pipe-separated file (``prefix_res`` by default, or the path
given as the first command-line argument).  Fields are read by index:
COUNTRY_CODE_INDEX, IP_TYPE_INDEX, IP_INDEX and IP_SIZE_INDEX.  For
``ipv4`` rows the size field is used as an address count; for ``ipv6``
rows it is used as a prefix length.  The generated C source is written to
stdout; diagnostics go to stderr.
"""
import sys
from functools import cmp_to_key

COUNTRY_CODE_INDEX = 1
IP_TYPE_INDEX = 2
IP_INDEX = 3
IP_SIZE_INDEX = 4


class IP_ELEMENT(object):
    """One node of the generated lookup tree: an address range or a group.

    Subclasses must set ``_separator``, ``_base`` and ``_format`` before
    calling this constructor, and implement ``get_ip_len`` and
    ``_compute_last_ip``.

    Args:
        start: first address of the range, as a separator-joined string.
        end: last address of the range; computed from ``size`` when falsy.
        size: range size, interpreted by the subclass ``_compute_last_ip``.
        country_code: index into the country table, or None (groups).
        level: index of the address byte this node represents in the tree.
        is_group: True for synthetic nodes that only hold children.
    """

    def __init__(self, start, end=None, size=0, country_code=None, level=0,
                 is_group=False):
        self._start = start
        self._end = end
        self._size = size
        self._country_code = country_code
        self._prev = None
        self._next = None
        self._childs = None
        self._average = 0
        self._level = level
        self._is_group = is_group
        if not self._end:
            self._compute_last_ip()
        self._splitted_start = self.split_ip(self._start)
        self._splitted_end = self.split_ip(self._end)

    def split_ip(self, ip):
        """Return the components of address string ``ip`` as ints."""
        return [int(x, self._base) for x in ip.split(self._separator)]

    def ip_to_str(self, int_ip):
        """Format the low ``get_ip_len()`` bytes of ``int_ip`` as an address."""
        res = [self._format % ((int_ip >> (i * 8)) & 0xFF)
               for i in reversed(range(self.get_ip_len()))]
        return self._separator.join(res)

    def ip_array_to_int(self, array):
        """Combine address components (most significant first) into an int."""
        val = 0
        for component in array:
            val = (val << 8) + component
        return val

    def ip_to_int(self, str_ip):
        """Convert address string ``str_ip`` into an int."""
        return self.ip_array_to_int(self.split_ip(str_ip))

    def make_group(self):
        """Return this start address with every byte after ``_level`` zeroed."""
        ip_val = list(self._splitted_start)
        for i in range(self._level + 1, self.get_ip_len()):
            ip_val[i] = 0
        return self._separator.join([self._format % x for x in ip_val])

    def name(self):
        """Return the C identifier used for this node."""
        name = 'ip__'
        if self._is_group:
            name += 'g%d_' % (self._level)
        return name + '%s__%s' % (self._start.replace(self._separator, '_'),
                                  self._end.replace(self._separator, '_'))

    def _compute_last_ip(self):
        raise NotImplementedError()

    def set_next(self, ip):
        self._next = ip

    def set_prev(self, ip):
        self._prev = ip

    def set_childs(self, ip):
        self._childs = ip

    def set_average(self, average):
        self._average = average

    def set_level(self, level):
        self._level = level

    def printme(self):
        """Print the C ``ip_level`` definition of this node to stdout."""
        def ref(node):
            # Pointer to a sibling/child node, or NULL when absent.
            return '&%s' % (node.name()) if node else 'NULL'

        print('static const ip_level %s = {' % (self.name()))
        print('\t.prev = %s,' % (ref(self._prev)))
        print('\t.next = %s,' % (ref(self._next)))
        print('\t.childs = %s,' % (ref(self._childs)))
        # Only the byte at this node's tree level is emitted.
        print('\t.start = %d,' % (self._splitted_start[self._level]))
        print('\t.end = %d,' % (self._splitted_end[self._level]))
        print('\t.average = %d,' % (self._average))
        # NOTE(review): groups (None) and the first country (index 0) both
        # emit code 0; confirm the C side does not need to tell them apart.
        print('\t.code = %d,' % (self._country_code or 0))
        print('};')

    def get_ip_len(self):
        raise NotImplementedError()


class IP_ELEMENT4(IP_ELEMENT):
    """IPv4 node: dotted-decimal bytes; ``size`` is an address count."""

    def __init__(self, start, end=None, size=0, country_code=None, level=0,
                 is_group=False):
        self._separator = '.'
        self._base = 10
        self._format = '%d'
        super(IP_ELEMENT4, self).__init__(start, end, size, country_code,
                                          level, is_group)

    def get_ip_len(self):
        return 4

    def _compute_last_ip(self):
        # The range covers ``size`` addresses: start .. start + size - 1.
        # A per-byte computation is only right for counts of the form
        # k * 256**j with k < 256 (e.g. it mapped 98304 to start + 32767).
        # A size of 0 keeps end == start.
        end_ip = self.ip_to_int(self._start) + max(self._size - 1, 0)
        self._end = self.ip_to_str(end_ip)


class IP_ELEMENT6(IP_ELEMENT):
    """IPv6 node: 16 colon-separated hex bytes; ``size`` is a prefix length."""

    def __init__(self, start, end=None, size=0, country_code=None, level=0,
                 is_group=False):
        self._separator = ':'
        self._base = 16
        self._format = '%02x'
        super(IP_ELEMENT6, self).__init__(start, end, size, country_code,
                                          level, is_group)

    def get_ip_len(self):
        return 16

    def _get_mask(self):
        """Return the 128-bit network mask of a ``/_size`` prefix."""
        return ((1 << self._size) - 1) << (128 - self._size)

    def _compute_last_ip(self):
        if self._size == 0:
            self._end = self._start[:]
        else:
            mask = self._get_mask()
            # ~mask is negative in Python; ip_to_str only keeps the low
            # 16 bytes, which are exactly the host bits set to one.
            self._end = self.ip_to_str(self.ip_to_int(self._start) | ~mask)


def extend_ipv6(ipv6):
    """Expand a textual IPv6 address to 16 colon-separated hex byte pairs.

    A ``::`` may appear anywhere and stands for enough zero groups to make
    eight 16-bit groups.  Without ``::``, missing trailing groups and empty
    groups are read as zero.  Hex digits keep their original case.

    Raises:
        ValueError: if there are more than eight groups or a group has
            more than four hex digits.
    """
    if '::' in ipv6:
        head, tail = ipv6.split('::', 1)
        head_groups = head.split(':') if head else []
        tail_groups = tail.split(':') if tail else []
        missing = 8 - len(head_groups) - len(tail_groups)
        groups = head_groups + ['0'] * missing + tail_groups
    else:
        groups = ipv6.split(':')
        groups += ['0'] * (8 - len(groups))
    if len(groups) != 8 or any(len(g) > 4 for g in groups):
        raise ValueError('Invalid IPv6 address %s' % (ipv6))
    hex_digits = ''.join(g.rjust(4, '0') for g in groups)
    return ':'.join(hex_digits[i:i + 2] for i in range(0, 32, 2))


def parse_prefix_lines(lines):
    """Parse pipe-separated rows into a country table and address ranges.

    Rows that are comments, too short, have no country, or use '*' as the
    address are skipped.  Rows of an unknown IP type are reported on stderr
    and skipped without registering their country.

    Returns:
        (countries, array_vals_ipv4, array_vals_ipv6): the lower-cased
        country codes in first-seen order, and dicts keyed by start
        address string whose values are IP_ELEMENT4 / IP_ELEMENT6 nodes
        carrying the country's index in ``countries``.
    """
    countries = []
    country_index = {}
    array_vals_ipv4 = {}
    array_vals_ipv6 = {}
    for l in lines:
        if l.startswith('#'):
            continue
        information = l.split('|')
        if len(information) <= IP_SIZE_INDEX:
            continue  # Blank line or not a data row
        country = information[COUNTRY_CODE_INDEX].lower()
        if not country:
            continue  # Available or reserved but not assigned
        ip_type = information[IP_TYPE_INDEX]
        if ip_type not in ('ipv4', 'ipv6'):
            sys.stderr.write('Unknown IP type %s\n' % (ip_type))
            continue
        ip = information[IP_INDEX]
        if ip == '*':
            continue  # Not an address (e.g. a summary row)
        country_idx = country_index.get(country)
        if country_idx is None:
            country_idx = len(countries)
            country_index[country] = country_idx
            countries.append(country)
        size = int(information[IP_SIZE_INDEX])
        if ip_type == 'ipv4':
            array_vals_ipv4[ip] = IP_ELEMENT4(ip, None, size, country_idx)
        else:
            ip = extend_ipv6(ip)
            array_vals_ipv6[ip] = IP_ELEMENT6(ip, None, size, country_idx)
    return countries, array_vals_ipv4, array_vals_ipv6


def ip_sort(a, b):
    """cmp-style comparison of two nodes by their start address bytes."""
    for i in range(0, a.get_ip_len()):
        if a._splitted_start[i] != b._splitted_start[i]:
            return a._splitted_start[i] - b._splitted_start[i]
    return 0


def get_interval(root, intervals, level):
    """Return the leading run of ``intervals`` whose byte ``level`` == root."""
    new_intervals = []
    for ip in intervals:
        if ip._splitted_start[level] != root:
            break
        new_intervals.append(ip)
    return new_intervals


def print_interval(interval):
    """Return a debug string listing the node names in ``interval``."""
    return '[' + ''.join('%s,\n' % (i.name()) for i in interval) + ']'


def compute_average(root):
    """Store on ``root`` floor(log2) of its children's mean width.

    Each child's width is measured on the byte at that child's own level.
    Nothing is stored when ``root`` has no children or the mean is >= 256.
    """
    total = 0
    count = 0
    child = root._childs
    while child:
        total += 1
        count += (child._splitted_end[child._level] -
                  child._splitted_start[child._level] + 1)
        child = child._next
    if not total:
        return
    average = int(count / total)
    # The smallest i with average < 2**i gives i - 1 == floor(log2(average))
    # for 1 <= average < 256 (i.e. the highest power of 2 <= average).
    for i in range(0, 9):
        if average < (1 << i):
            root.set_average(i - 1)
            break


def manage_root(root, intervals, level, max_depth):
    """Link sorted ``intervals`` into a sibling list at tree ``level``.

    Consecutive entries sharing the byte at ``level`` are moved under a new
    group node whose children are built recursively at ``level + 1``.
    ``root`` is not used.  Returns the first node of the list, or None when
    ``level >= max_depth`` or ``intervals`` is empty.
    """
    if level >= max_depth:
        return None
    cur_start = 0
    prev = None
    first = None
    while cur_start < len(intervals):
        cur_ip = intervals[cur_start]
        sub_interval = get_interval(cur_ip._splitted_start[level],
                                    intervals[cur_start + 1:],
                                    level)
        if sub_interval:
            cur_ip.set_level(level + 1)
            for ip in sub_interval:
                ip.set_level(level + 1)
            # cur_ip's level is already level + 1 here, so the group
            # address keeps bytes 0..level + 1 of cur_ip's start.
            new_group = cur_ip.__class__(cur_ip.make_group(), level=level,
                                         is_group=True)
            sub_interval.insert(0, cur_ip)
            child = manage_root(cur_ip._splitted_start[level + 1],
                                sub_interval, level + 1, max_depth)
            new_group.set_childs(child)
            compute_average(new_group)
            cur_ip = new_group
            cur_start += len(sub_interval)
        else:
            cur_ip.set_level(level)
            cur_start += 1
        cur_ip.set_prev(prev)
        if prev:
            prev.set_next(cur_ip)
        prev = cur_ip
        if not first:
            first = cur_ip
    return first


def print_ip(ip):
    """Print C declarations and definitions for a sibling list and subtrees.

    Children are emitted first, then forward declarations for the list, so
    every referenced node is declared before it is used.
    """
    cur_ip = ip
    while cur_ip:
        if cur_ip._childs:
            print_ip(cur_ip._childs)
        print('static const ip_level %s;' % (cur_ip.name()))
        cur_ip = cur_ip._next
    print('')
    cur_ip = ip
    while cur_ip:
        cur_ip.printme()
        cur_ip = cur_ip._next


def build_array(ip_list, array_name, max_depth):
    """Print the trees for ``ip_list`` and a 256-entry root array.

    Nodes are sorted by start address and split into runs sharing their
    first byte; each run becomes one tree indexed by that byte in the C
    array ``array_name``.  Unused slots are NULL.  An empty ``ip_list``
    yields an all-NULL array.
    """
    ip_list = sorted(ip_list, key=cmp_to_key(ip_sort))
    root_ips = [None] * 256

    def emit_root(interval):
        # Build and print one first-byte tree, then register it in the array.
        res = manage_root(interval[0]._splitted_start[0], interval, 1,
                          max_depth)
        print_ip(res)
        if res:
            root_ips[res._splitted_start[0]] = res

    cur_interval = []
    for ip in ip_list:
        if (cur_interval and
                ip._splitted_start[0] != cur_interval[0]._splitted_start[0]):
            emit_root(cur_interval)
            cur_interval = []
        cur_interval.append(ip)
    if cur_interval:
        # The final run must be registered too, or its slot stays NULL.
        emit_root(cur_interval)

    print('\nstatic const ip_level* %s[256] = {' % (array_name))
    for i in range(0, 256):
        if root_ips[i]:
            print('\t&%s,' % (root_ips[i].name()))
        else:
            print('\tNULL, // %d' % (i))
    print('};\n')


def main(argv=None):
    """Read the prefix file and print the generated C source to stdout."""
    if argv is None:
        argv = sys.argv
    path = argv[1] if len(argv) > 1 else "prefix_res"
    with open(path) as f:
        countries, array_vals_ipv4, array_vals_ipv6 = parse_prefix_lines(f)

    print('/* This file was automatically generated, do not edit it ! */')
    # NOTE(review): the #include target appears to be missing from this
    # string; confirm which header declares ip_level and uint8_t.
    print('#include \n\n')

    build_array(list(array_vals_ipv4.values()), 's_root_ipv4', 3)
    build_array(list(array_vals_ipv6.values()), 's_root_ipv6', 15)

    print('static const uint8_t country_codes[][3] = {')
    for cc in countries:
        print('\t{"%s"},' % (cc))
    print('};\n')


if __name__ == '__main__':
    main()