#define _GNU_SOURCE 1 // for POLLRDHUP && syncfs #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef USE_SECCOMP #include #endif #include "ip_to_geo.h" #include "protocol.h" #define WAIT_TIME 100 #define MAX_WAIT_TIME 500 typedef struct { int socket; time_t timeout; int nb_remaining_requests; } socket_ctx_t; typedef struct thread_ctx_s{ struct thread_ctx_s* prev; struct thread_ctx_s* next; pthread_t thread; socket_ctx_t* sockets; int nb_cur_sockets; int nb_available_sockets; int max_timeout; int max_sockets; int stop; int quiet; pthread_mutex_t mutex; struct pollfd * pollfds; } thread_ctx_t; static pthread_mutex_t s_fastmutex = PTHREAD_MUTEX_INITIALIZER; static thread_ctx_t* s_last_thread = NULL; static int s_server_socket = -1; static int s_stop = 0; void sigint(int sig) { syslog(LOG_WARNING, "signal received, stopping threads"); s_stop = 1; shutdown(s_server_socket, SHUT_RDWR); } static int check_request(request_t* req) { if (req->magic != REQ_MAGIC) return REQ_ERR_BAD_MAGIC; if (req->version != REQ_VERSION) return REQ_ERR_BAD_VERSION; if (req->ip_type != REQ_IPV4 && req->ip_type != REQ_IPV6) return REQ_ERR_BAD_IP_VERSION; if (req->ip_type != REQ_IPV4) return REQ_ERR_UNSUPPORTED_IP_VERSION; if (req->req != REQ_REQ) return REQ_ERR_BAD_REQ_FIELD; return REQ_ERR_NO_ERR; } static int handle_request(thread_ctx_t* thread_ctx, int socket) { request_t req; const uint8_t* geo; int sent=0; int ret = read(socket, &req, sizeof(req)); // Socket closed if (ret == 0) { if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Socket %d closed", socket); return 1; } if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "New request"); if (ret != sizeof(req)) { if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Invalid request size %d", ret); return -1; } ret = check_request(&req); req.req = REQ_RESP; if (ret) { req.err = ret; if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Request error %d", ret); } else { if (thread_ctx->quiet < 0) { char dst[64]; inet_ntop((req.ip_type == REQ_IPV4)?AF_INET:AF_INET6, req.ip, dst, sizeof(dst)); syslog(LOG_DEBUG, "Request for %s from socket %d", dst, socket); } geo = ip_to_geo(req.ip, req.ip_type); if (!geo) { req.err = REQ_IP_NOT_FOUND; if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Not found"); } else { req.err = REQ_ERR_NO_ERR; geo = get_country_code(geo); if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Res %s", geo); req.country_code[0] = geo[0]; req.country_code[1] = geo[1]; req.country_code[2] = 0; req.country_code[3] = 0; } } for (sent=0; sent < sizeof(req); sent += ret) { ret = write(socket, &((uint8_t*)&req)[sent], sizeof(req)-sent); if (ret < 0) return -1; } return 0; } static void delete_thread(thread_ctx_t* thread_ctx) { int i; pthread_mutex_lock(&s_fastmutex); thread_ctx->nb_available_sockets = 0; if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Delete thread %p", thread_ctx); for(i=0; inb_cur_sockets; i++) { if (thread_ctx->sockets[i].timeout > 0) { close (thread_ctx->sockets[i].socket); } } free(thread_ctx->sockets); free(thread_ctx->pollfds); if (thread_ctx->next) thread_ctx->next->prev = thread_ctx->prev; if (thread_ctx->prev) thread_ctx->prev->next = thread_ctx->next; if (thread_ctx == s_last_thread) s_last_thread = thread_ctx->next; pthread_mutex_unlock(&s_fastmutex); free(thread_ctx); } static inline void close_socket(socket_ctx_t* socket) { socket->timeout = -1; close(socket->socket); } #define POLL_ERR_MASK (POLLRDHUP|POLLERR|POLLHUP|POLLNVAL) static void* thread_loop(void* param) { thread_ctx_t* ctx = (thread_ctx_t*)param; int i, ret, nfds, nb_cur_sockets, nb_available_sockets, poll_idx; struct timeval time1, time2, time_res; int wait_time = WAIT_TIME; while (!ctx->stop) { nfds = 0; pthread_mutex_lock(&ctx->mutex); nb_cur_sockets = ctx->nb_cur_sockets; nb_available_sockets = ctx->nb_available_sockets; pthread_mutex_unlock(&ctx->mutex); for(i=0; isockets[i].timeout > 0) { ctx->pollfds[nfds].fd = ctx->sockets[i].socket; ctx->pollfds[nfds].events = POLLIN|POLL_ERR_MASK; nfds++; } } if (!nfds) { /* No more active socket for this thread nor available slots */ if (!nb_available_sockets) break; if (wait_time < MAX_WAIT_TIME) wait_time += WAIT_TIME; usleep(wait_time); continue; } else wait_time = WAIT_TIME; gettimeofday(&time1, NULL); ret = poll(ctx->pollfds, nfds, ctx->max_timeout); gettimeofday(&time2, NULL); // Timeout, remove all current sockets if (ret == 0) { if (ctx->quiet < 0) syslog(LOG_DEBUG, "Timeout"); for(i=0; isockets[i].timeout > 0) close_socket(&ctx->sockets[i]); } } else if (ret < 0) { if (!s_stop && !ctx->stop) syslog(LOG_WARNING, "poll has errors (%m)\n"); break; } else { timersub(&time2, &time1, &time_res); poll_idx = -1; for(i=0; inb_cur_sockets; i++) { if (ctx->sockets[i].timeout < 0) continue; poll_idx++; if (ctx->pollfds[poll_idx].fd != ctx->sockets[i].socket) { if (ctx->quiet < 0) syslog(LOG_ERR, "Socket not found but present in poll fds"); continue; } // Error if (ctx->pollfds[poll_idx].revents & POLL_ERR_MASK) { if (ctx->quiet < 0) syslog(LOG_ERR, "Error with socket %d", ctx->sockets[i].socket); close_socket(&ctx->sockets[i]); } // Someone is speaking else if (ctx->pollfds[poll_idx].revents & POLLIN) { ctx->sockets[i].timeout = ctx->max_timeout*1000; ret = handle_request(ctx, ctx->sockets[i].socket); if (ret == 1) { if (ctx->quiet < 0) syslog(LOG_DEBUG, "Client has closed socket %d", ctx->sockets[i].socket); close_socket(&ctx->sockets[i]); } // No more requests accepted else if (!ctx->sockets[i].nb_remaining_requests--) { if (ctx->quiet < 0) syslog(LOG_DEBUG, "Max requests reached for socket %d", ctx->sockets[i].socket); syncfs(ctx->sockets[i].socket); close_socket(&ctx->sockets[i]); } } else { ctx->sockets[i].timeout -= (time_res.tv_sec*1000000 + time_res.tv_usec); if (ctx->sockets[i].timeout <= 0) close_socket(&ctx->sockets[i]); } } } }; delete_thread(ctx); pthread_exit(NULL); return NULL; } static inline thread_ctx_t* create_thread_ctx(struct gengetopt_args_info* params) { thread_ctx_t* thread_ctx = malloc(sizeof(*thread_ctx)); if (params->verbose_flag) syslog(LOG_DEBUG, "Create a new thread %p", thread_ctx); thread_ctx->sockets = malloc(sizeof(*thread_ctx->sockets)*params->sockets_per_thread_arg); thread_ctx->pollfds = malloc(sizeof(*thread_ctx->pollfds)*params->sockets_per_thread_arg); thread_ctx->nb_cur_sockets = 0; thread_ctx->nb_available_sockets = params->sockets_per_thread_arg; thread_ctx->max_timeout = params->sockets_timeout_arg*1000; thread_ctx->stop = 0; thread_ctx->quiet = params->quiet_flag; if (params->verbose_flag) thread_ctx->quiet = -1; thread_ctx->prev = NULL; pthread_mutex_init(&thread_ctx->mutex, NULL); thread_ctx->next = s_last_thread; if (s_last_thread) s_last_thread->prev = thread_ctx; else s_last_thread = thread_ctx; return thread_ctx; } static void fill_new_socket(struct gengetopt_args_info* params, int socket) { thread_ctx_t* thread_ctx; int launch_thread = 0; pthread_mutex_lock(&s_fastmutex); thread_ctx = s_last_thread; if (!thread_ctx || !thread_ctx->nb_available_sockets) { thread_ctx = create_thread_ctx(params); launch_thread = 1; } pthread_mutex_unlock(&s_fastmutex); thread_ctx->sockets[thread_ctx->nb_cur_sockets].socket = socket; thread_ctx->sockets[thread_ctx->nb_cur_sockets].timeout = thread_ctx->max_timeout*1000; // ms -> us thread_ctx->sockets[thread_ctx->nb_cur_sockets].nb_remaining_requests = params->client_max_requests_arg; pthread_mutex_lock(&thread_ctx->mutex); thread_ctx->nb_cur_sockets++; thread_ctx->nb_available_sockets--; pthread_mutex_unlock(&thread_ctx->mutex); if (launch_thread) pthread_create(&thread_ctx->thread, NULL, thread_loop, thread_ctx); } int daemonize(struct gengetopt_args_info* params) { int ret; struct sockaddr_in sockaddr; socklen_t sockaddr_len; int new_socket; void* thread_ret; // Should have both ipv4 & ipv6 s_server_socket = socket(AF_INET, SOCK_STREAM, 0); // Should have both TCP & UDP if (!s_server_socket) { if (!params->quiet_flag) fprintf(stderr, "Unable to create socket (%m)\n"); return -1; } memset(&sockaddr, 0, sizeof(sockaddr)); sockaddr.sin_family = AF_INET; // Should detect interface type (v4 or v6) sockaddr.sin_port = htons(params->port_arg); if (params->bind_ip_given) { ret = inet_aton(params->bind_ip_arg, &sockaddr.sin_addr); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Error with bind address %s (%m)\n", params->bind_ip_arg); return -1; } } else sockaddr.sin_addr.s_addr = INADDR_ANY; ret = bind(s_server_socket, (struct sockaddr *)&sockaddr, sizeof(sockaddr)); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Unable to bind (%m)\n"); return -2; } ret = listen(s_server_socket, 0); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Unable to listen (%m)\n"); return -3; } if (!params->no_background_flag) { ret = daemon(0, 0); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Daemon error (%m)\n"); return -4; } } openlog("ip_to_geod", 0, LOG_DAEMON); syslog(LOG_INFO, "ip_togeod started\n"); signal(SIGINT, sigint); signal(SIGUSR1, sigint); signal(SIGUSR2, sigint); #ifdef USE_SECCOMP scmp_filter_ctx seccomp_ctx = seccomp_init(SCMP_ACT_KILL); if (seccomp_ctx == NULL) { syslog(LOG_ERR, "unable to initialize seccomp\n"); return -5; } seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(read), 0); seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(write), 0); seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(close), 0); seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(accept), 0); #endif while (!s_stop) { sockaddr_len = sizeof(sockaddr); new_socket = accept(s_server_socket, (struct sockaddr *) &sockaddr, &sockaddr_len); if (new_socket < 0) { if (!s_stop) syslog(LOG_ERR, "accept error (%m), exiting"); break; } if (!params->quiet_flag) syslog(LOG_INFO, "new connection from %s, socket %d", inet_ntoa(sockaddr.sin_addr), new_socket); fill_new_socket(params, new_socket); } close(s_server_socket); while (s_last_thread) { s_last_thread->stop = 1; pthread_join(s_last_thread->thread, &thread_ret); } closelog(); #ifdef USE_SECCOMP if (seccomp_ctx) seccomp_release(seccomp_ctx); #endif return 0; }