#include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef USE_SECCOMP #include #endif #include "ip_to_geo.h" #include "protocol.h" typedef struct { int socket; time_t timeout; int nb_remaining_requests; } socket_ctx_t; // TODO : sandbox typedef struct thread_ctx_s{ struct thread_ctx_s* prev; struct thread_ctx_s* next; pthread_t thread; socket_ctx_t* sockets; int nb_cur_sockets; int nb_available_sockets; int max_timeout; int max_sockets; int stop; int quiet; pthread_mutex_t mutex; } thread_ctx_t; static pthread_mutex_t s_fastmutex = PTHREAD_MUTEX_INITIALIZER; static thread_ctx_t* s_last_thread = NULL; static int s_server_socket = -1; static int s_stop = 0; void sigint(int sig) { syslog(LOG_WARNING, "signal received, stopping threads"); s_stop = 1; shutdown(s_server_socket, SHUT_RDWR); } // TODO signal capture static int check_request(request_t* req) { if (req->magic != REQ_MAGIC) return REQ_ERR_BAD_MAGIC; if (req->version != REQ_VERSION) return REQ_ERR_BAD_VERSION; if (req->ip_type != REQ_IPV4 && req->ip_type != REQ_IPV6) return REQ_ERR_BAD_IP_VERSION; if (req->ip_type != REQ_IPV4) return REQ_ERR_UNSUPPORTED_IP_VERSION; if (req->req != REQ_REQ) return REQ_ERR_BAD_REQ_FIELD; return REQ_ERR_NO_ERR; } static int handle_request(thread_ctx_t* thread_ctx, int socket) { request_t req; const uint8_t* geo; int ret = read(socket, &req, sizeof(req)); // Socket closed if (ret == 0) { if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Socket %d closed", socket); return 1; } if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "New request"); if (ret != sizeof(req)) { if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Invalid request size %d", ret); return -1; } ret = check_request(&req); req.req = REQ_RESP; if (ret) { req.err = ret; if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Request error %d", ret); } else { if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Request for %08x from socket %d", req.ipv4, socket); geo = ip_to_geo(req.ipv4); if (!geo) { req.err = REQ_IP_NOT_FOUND; if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Not found"); } else { req.err = REQ_ERR_NO_ERR; geo = get_country_code(geo); if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Res %s", geo); req.country_code[0] = geo[0]; req.country_code[1] = geo[1]; req.country_code[2] = 0; req.country_code[3] = 0; } } write(socket, &req, sizeof(req)); return 0; } static void delete_thread(thread_ctx_t* thread_ctx) { int i; pthread_mutex_lock(&s_fastmutex); thread_ctx->nb_available_sockets = 0; if (thread_ctx->quiet < 0) syslog(LOG_DEBUG, "Delete thread %p", thread_ctx); for(i=0; inb_cur_sockets; i++) { if (thread_ctx->sockets[i].timeout > 0) { close (thread_ctx->sockets[i].socket); } } free(thread_ctx->sockets); if (thread_ctx->next) thread_ctx->next->prev = thread_ctx->prev; if (thread_ctx->prev) thread_ctx->prev->next = thread_ctx->next; if (thread_ctx == s_last_thread) s_last_thread = thread_ctx->next; pthread_mutex_unlock(&s_fastmutex); free(thread_ctx); } static inline void close_socket(socket_ctx_t* socket) { socket->timeout = -1; close(socket->socket); } static void* thread_loop(void* param) { thread_ctx_t* ctx = (thread_ctx_t*)param; int i, ret, nfds; fd_set read_set, exc_set; struct timeval timeout; while (!ctx->stop) { FD_ZERO(&read_set); FD_ZERO(&exc_set); nfds = 0; for(i=0; inb_cur_sockets; i++) { if (ctx->sockets[i].timeout > 0) { FD_SET(ctx->sockets[i].socket, &read_set); FD_SET(ctx->sockets[i].socket, &exc_set); if (ctx->sockets[i].socket+1 > nfds) nfds = ctx->sockets[i].socket+1; } } if (!nfds) { // No more active socket for this thread if (!ctx->nb_available_sockets) break; usleep(100); continue; } timeout.tv_sec = ctx->max_timeout; timeout.tv_usec = 0; ret = select(nfds, &read_set, NULL, &exc_set, &timeout); pthread_mutex_lock(&ctx->mutex); // Timeout, remove all current sockets if (ret == 0) { if (ctx->quiet < 0) syslog(LOG_DEBUG, "Timeout"); for(i=0; inb_cur_sockets; i++) { if (ctx->sockets[i].timeout > 0) close_socket(&ctx->sockets[i]); } } else if (ret < 0) { if (!s_stop && !ctx->stop) syslog(LOG_WARNING, "select has errors (%m)\n"); } else { for(i=0; inb_cur_sockets; i++) { if (ctx->sockets[i].timeout < 0) continue; if (FD_ISSET(ctx->sockets[i].socket, &exc_set)) { close_socket(&ctx->sockets[i]); continue; } // Someone is speaking if (FD_ISSET(ctx->sockets[i].socket, &read_set)) { ctx->sockets[i].timeout = ctx->max_timeout; ret = handle_request(ctx, ctx->sockets[i].socket); if (ret == 1) { if (ctx->quiet < 0) syslog(LOG_DEBUG, "Client has closed socket %d", ctx->sockets[i].socket); close_socket(&ctx->sockets[i]); } // No more requests accepted if (!ctx->sockets[i].nb_remaining_requests--) { if (ctx->quiet < 0) syslog(LOG_DEBUG, "Max requests reached for socket %d", ctx->sockets[i].socket); syncfs(ctx->sockets[i].socket); close_socket(&ctx->sockets[i]); } } else { ctx->sockets[i].timeout -= timeout.tv_sec; } } } pthread_mutex_unlock(&ctx->mutex); }; delete_thread(ctx); pthread_exit(NULL); return NULL; } static inline thread_ctx_t* create_thread_ctx(struct gengetopt_args_info* params) { thread_ctx_t* thread_ctx = malloc(sizeof(*thread_ctx)); if (params->verbose_flag) syslog(LOG_DEBUG, "Create a new thread %p", thread_ctx); thread_ctx->sockets = malloc(sizeof(*thread_ctx->sockets)*params->sockets_per_thread_arg); thread_ctx->nb_cur_sockets = 0; thread_ctx->nb_available_sockets = params->sockets_per_thread_arg; thread_ctx->max_timeout = params->sockets_timeout_arg; thread_ctx->stop = 0; thread_ctx->quiet = params->quiet_flag; if (params->verbose_flag) thread_ctx->quiet = -1; thread_ctx->prev = NULL; pthread_mutex_init(&thread_ctx->mutex, NULL); thread_ctx->next = s_last_thread; if (s_last_thread) s_last_thread->prev = thread_ctx; else s_last_thread = thread_ctx; return thread_ctx; } static void fill_new_socket(struct gengetopt_args_info* params, int socket) { thread_ctx_t* thread_ctx; int launch_thread = 0; pthread_mutex_lock(&s_fastmutex); thread_ctx = s_last_thread; if (!thread_ctx || !thread_ctx->nb_available_sockets) { thread_ctx = create_thread_ctx(params); launch_thread = 1; } pthread_mutex_unlock(&s_fastmutex); thread_ctx->sockets[thread_ctx->nb_cur_sockets].socket = socket; thread_ctx->sockets[thread_ctx->nb_cur_sockets].timeout = thread_ctx->max_timeout; thread_ctx->sockets[thread_ctx->nb_cur_sockets].nb_remaining_requests = params->client_max_requests_arg; pthread_mutex_lock(&thread_ctx->mutex); thread_ctx->nb_cur_sockets++; thread_ctx->nb_available_sockets--; pthread_mutex_unlock(&thread_ctx->mutex); if (launch_thread) pthread_create(&thread_ctx->thread, NULL, thread_loop, thread_ctx); } int daemonize(struct gengetopt_args_info* params) { int ret; struct sockaddr_in sockaddr; socklen_t sockaddr_len; int new_socket; void* thread_ret; // Should have both ipv4 & ipv6 s_server_socket = socket(AF_INET, SOCK_STREAM, 0); // Should have both TCP & UDP if (!s_server_socket) { if (!params->quiet_flag) fprintf(stderr, "Unable to create socket (%m)\n"); return -1; } memset(&sockaddr, 0, sizeof(sockaddr)); sockaddr.sin_family = AF_INET; // Should detect interface type (v4 or v6) sockaddr.sin_port = htons(params->port_arg); if (params->bind_ip_given) { ret = inet_aton(params->bind_ip_arg, &sockaddr.sin_addr); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Error with bind address %s (%m)\n", params->bind_ip_arg); return -1; } } else sockaddr.sin_addr.s_addr = INADDR_ANY; ret = bind(s_server_socket, (struct sockaddr *)&sockaddr, sizeof(sockaddr)); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Unable to bind (%m)\n"); return -2; } ret = listen(s_server_socket, 0); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Unable to listen (%m)\n"); return -3; } if (!params->no_background_flag) { ret = daemon(0, 0); if (ret) { if (!params->quiet_flag) fprintf(stderr, "Daemon error (%m)\n"); return -4; } } openlog("ip_to_geod", 0, LOG_DAEMON); syslog(LOG_INFO, "ip_togeod started\n"); signal(SIGINT, sigint); signal(SIGUSR1, sigint); signal(SIGUSR2, sigint); #ifdef USE_SECCOMP scmp_filter_ctx seccomp_ctx = seccomp_init(SCMP_ACT_KILL); if (seccomp_ctx == NULL) { syslog(LOG_ERR, "unable to initialize seccomp\n"); return -5; } seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(read), 0); seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(write), 0); seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(close), 0); seccomp_rule_add(seccomp_ctx, SCMP_ACT_ALLOW, SCMP_SYS(accept), 0); #endif while (!s_stop) { sockaddr_len = sizeof(sockaddr); new_socket = accept(s_server_socket, (struct sockaddr *) &sockaddr, &sockaddr_len); if (new_socket < 0) { if (!s_stop) syslog(LOG_ERR, "accept error (%m), exiting"); break; } if (!params->quiet_flag) syslog(LOG_INFO, "new connection from %s, socket %d", inet_ntoa(sockaddr.sin_addr), new_socket); fill_new_socket(params, new_socket); } close(s_server_socket); while (s_last_thread) { s_last_thread->stop = 1; pthread_join(s_last_thread->thread, &thread_ret); } closelog(); #ifdef USE_SECCOMP if (seccomp_ctx) seccomp_release(seccomp_ctx); #endif return 0; }