diff --git a/src/ip_to_geo.c b/src/ip_to_geo.c index 3c9f017..49c83e8 100644 --- a/src/ip_to_geo.c +++ b/src/ip_to_geo.c @@ -10,56 +10,72 @@ #include "ip_data.c" -static const uint8_t* ip_to_geo_rec(uint32_t ipv4, unsigned level, const ip_level* root) +static const uint8_t* ip_to_geo_rec(uint8_t* ip, unsigned level, const ip_level* root) { unsigned cur_average; - const ip_level* cur_ip = root; - unsigned cur_addr = (ipv4 >> (level*8)) & 0xFF; + const ip_level* cur_ip; + unsigned cur_addr; - // Optimistic search - if (cur_addr && level != 2) + while (1) { - cur_average = cur_addr >> root->average; + cur_ip = root; + cur_addr = ip[level]; - while (cur_average-- && cur_ip->next) - cur_ip = cur_ip->next; - } + // Optimistic search + if (cur_addr && level != 1) + { + cur_average = cur_addr >> root->average; + + while (cur_average-- && cur_ip->next) + cur_ip = cur_ip->next; + } #define IP_TEST \ - do { \ - if (cur_addr >= cur_ip->start && cur_addr <= cur_ip->end) \ { \ - if (cur_ip->childs) \ - return ip_to_geo_rec(ipv4, level-1, cur_ip->childs); \ - else \ - return &cur_ip->code; \ - } \ - } while (0) + if (cur_addr >= cur_ip->start && cur_addr <= cur_ip->end) \ + { \ + if (cur_ip->childs) \ + { \ + level++; \ + root = cur_ip->childs; \ + continue; \ + } \ + else \ + return &cur_ip->code; \ + } \ + } - if (cur_addr < cur_ip->start) - { - for (cur_ip = cur_ip->prev; cur_ip; cur_ip = cur_ip->prev) + if (cur_addr < cur_ip->start) + { + for (cur_ip = cur_ip->prev; cur_ip; cur_ip = cur_ip->prev) + IP_TEST; + } + else if (cur_addr > cur_ip->end) + { + for (cur_ip = cur_ip->next; cur_ip; cur_ip = cur_ip->next) + IP_TEST; + } + else IP_TEST; - } - else if (cur_addr > cur_ip->end) - { - for (cur_ip = cur_ip->next; cur_ip; cur_ip = cur_ip->next) - IP_TEST; - } - else - IP_TEST; + break; + } return NULL; } -const uint8_t* ip_to_geo(uint32_t ipv4) +const uint8_t* ip_to_geo(uint8_t* ip, unsigned ip_size) { - const ip_level* first_level = s_root_ip[ipv4 >> 24]; + const ip_level* first_level; + + if (ip_size == 4) + first_level = s_root_ip[ip[0]]; + else + return NULL; if (!first_level) return NULL; - return ip_to_geo_rec(ipv4, 2, first_level); + return ip_to_geo_rec(ip, 1, first_level); } const uint8_t* get_country_code(const uint8_t* idx) @@ -106,7 +122,7 @@ int interactive(struct gengetopt_args_info* params) return -1; } - cc = ip_to_geo((uint32_t)ret); + cc = ip_to_geo((uint8_t*)&ret, 4); if (params->quiet_flag) printf("%s\n", (cc)?(char*)get_country_code(cc):""); @@ -125,7 +141,7 @@ int main(int argc, char** argv) if (ret) return ret; - //self_test(); + /* self_test(); */ if (params.ip_given) return interactive(¶ms); diff --git a/src/ip_to_geo.h b/src/ip_to_geo.h index 44934a9..b074ce8 100644 --- a/src/ip_to_geo.h +++ b/src/ip_to_geo.h @@ -12,7 +12,7 @@ typedef struct ip_level_t { uint8_t code; } ip_level; -const uint8_t* ip_to_geo(uint32_t ipv4); +const uint8_t* ip_to_geo(uint8_t* ip, unsigned ip_size); const uint8_t* get_country_code(const uint8_t* idx); int strip_to_int(char* strip_, uint32_t* ip); diff --git a/src/protocol.h b/src/protocol.h index 1ed3de4..1822bc5 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -9,20 +9,12 @@ typedef struct { #define REQ_REQ 1 #define REQ_RESP 0 uint8_t req; -#define REQ_IPV4 32 -#define REQ_IPV6 128 +#define REQ_IPV4 4 +#define REQ_IPV6 16 uint8_t err; uint8_t ip_type; // 4 or 6 uint32_t flags; - union { - uint32_t ipv4; - struct { - uint32_t a; - uint32_t b; - uint32_t c; - uint32_t d; - }ipv6; - }; + uint8_t ip[16]; // ipv4 or ipv6 uint8_t country_code[4]; } request_t; diff --git a/src/server.c b/src/server.c index dca216d..03b6fc3 100644 --- a/src/server.c +++ b/src/server.c @@ -117,8 +117,13 @@ static int handle_request(thread_ctx_t* thread_ctx, int socket) else { if (thread_ctx->quiet < 0) - syslog(LOG_DEBUG, "Request for %08x from socket %d", req.ipv4, socket); - geo = ip_to_geo(req.ipv4); + { + char dst[64]; + inet_ntop((req.ip_type == REQ_IPV4)?AF_INET:AF_INET6, req.ip, dst, sizeof(dst)); + syslog(LOG_DEBUG, "Request for %s from socket %d", dst, socket); + } + + geo = ip_to_geo(req.ip, req.ip_type); if (!geo) { req.err = REQ_IP_NOT_FOUND; diff --git a/src/test.c b/src/test.c index c4e39a7..3268c13 100644 --- a/src/test.c +++ b/src/test.c @@ -6,8 +6,9 @@ static void do_test(int a, int b, int c, int d) { const uint8_t* cc; + uint8_t ip[4] = {a, b, c, d}; - cc = ip_to_geo(IP(a,b,c,d)); + cc = ip_to_geo(ip, 4); printf("IP %d.%d.%d.%d : %s\n", a, b, c, d, (cc)?(char*)get_country_code(cc):""); } diff --git a/tests/iptogeo.py b/tests/iptogeo.py index 3da3568..f887346 100644 --- a/tests/iptogeo.py +++ b/tests/iptogeo.py @@ -13,8 +13,8 @@ class IPToGeo(object): VERSION = 1 REQ = 1 RESP = 0 - IPV4 = 32 - IPV6 = 128 + IPV4 = 4 + IPV6 = 16 IP_NOT_FOUND = 6 @@ -40,13 +40,13 @@ class IPToGeo(object): self._socket.settimeout(self._timeout) self._socket.connect((self._remote_addr, self._remote_port)) - def _create_request(self, int_ip): + def _create_request(self, ip): packet = '' packet += struct.pack('