#include #include #include #include #include #include #include "client.h" #include "main.h" #include "packet.h" #include "ipqapi.h" #define MAX_QUEUE_LENGTH 10 #define BUFSIZE (10*1024) static int unix_socket = -1, client_socket = -1; static GSList *send_queue = NULL; static guint send_queue_length = 0; static guint send_index = 0; static guint8 recv_buf[BUFSIZE]; static guint recv_index = 0; int client_init() { struct sockaddr_un addr; mode_t save; save = umask(0007); if ((unix_socket = socket(PF_LOCAL, SOCK_STREAM, 0)) < 0) { umask(save); daemon_log(LOG_ERR, "socket(PF_LOCAL, SOCK_STREAM, 0): %s", strerror(errno)); return -1; } addr.sun_family = AF_LOCAL; strncpy(addr.sun_path, SOCKET_PATH, sizeof(addr.sun_path)); addr.sun_path[sizeof(addr.sun_path)-1] = 0; if (bind(unix_socket, (struct sockaddr *) &addr, SUN_LEN(&addr)) < 0) { close(unix_socket); umask(save); daemon_log(LOG_ERR, "bind(): %s", strerror(errno)); return (unix_socket = -1); } umask(save); if (listen(unix_socket, 1) < 0) { close(unix_socket); daemon_log(LOG_ERR, "listen(): %s", strerror(errno)); return (unix_socket = -1); } FD_SET(unix_socket, &listen_rfds); return 0; } void client_disconnect() { if (client_socket < 0) return; FD_CLR(client_socket, &listen_rfds); close(client_socket); client_socket = -1; while (send_queue) { message_t *m = (message_t*) send_queue->data; send_queue = g_slist_remove(send_queue, m); g_free(m); } send_queue_length = 0; send_index = 0; recv_index = 0; daemon_log(LOG_INFO, "Client disconnected"); } void client_done() { if (client_socket >= 0) client_disconnect(); if (unix_socket >= 0) { FD_CLR(unix_socket, &listen_rfds); close(unix_socket); unlink(SOCKET_PATH); } unix_socket = -1; } int client_work_send() { g_assert(client_socket >= 0); if (send_queue) { size_t l; ssize_t r; message_t *m = (message_t*) send_queue->data; g_assert(m); l = m->length + sizeof(message_t); if ((r = write(client_socket, ((guint8*) m) + send_index, l-send_index)) <= 0) { daemon_log(LOG_ERR, "Write error on client socket (%s)", strerror(errno)); return -1; } send_index += r; if (send_index >= l) { send_queue = g_slist_remove(send_queue, m); send_queue_length--; send_index = 0; g_free(m); } } if (!send_queue) FD_CLR(client_socket, &listen_wfds); else FD_SET(client_socket, &listen_wfds); return 0; } int client_dispatch(message_t*m) { switch (m->code) { case MSG_SET_DEFAULT_VERDICT: if (m->length != sizeof(guint32)) { daemon_log(LOG_WARNING, "Client sent MSG_SET_DEFAULT_VERDICT message with bogus size."); return -1; } default_verdict = *((guint*) (m+1)); return 0; case MSG_VERDICT: { ipq_packet_msg_t *ipqm; unsigned long packet_id; guint32 verdict; if (m->length != sizeof(unsigned long)+sizeof(guint32)) { daemon_log(LOG_WARNING, "Client sent MSG_VERDICT message with bogus size."); return -1; } packet_id = *((unsigned long*) (m+1)); verdict = *((guint32*) (((guint8*) (m+1)) + sizeof(unsigned long))); if (log_packets) daemon_log(LOG_DEBUG, "[%lu] Recieved client verdict %u", packet_id, verdict); if ((ipqm = packet_find(packet_id))) { if (ipqapi_verdict(ipqm, verdict) < 0) { daemon_log(LOG_ERR, "Could not verdict."); fail = TRUE; } packet_release(packet_id); } else daemon_log(LOG_WARNING, "Recieved verdict for unknown packet id, ignoring"); return 0; } default: daemon_log(LOG_WARNING, "Recieved bogus message from client."); return -1; } } int client_work_recv() { ssize_t r; size_t l; if (recv_index >= sizeof(message_t)) { l = ((message_t*) recv_buf)->length + sizeof(message_t); if (l > BUFSIZE) { daemon_log(LOG_WARNING, "Client message too large"); return -1; } } else l = sizeof(message_t); if ((r = read(client_socket, recv_buf + recv_index, l - recv_index)) <= 0) { if (r < 0) daemon_log(LOG_WARNING, "Read error on client socket (%s)", strerror(errno)); return -1; } recv_index += r; if (recv_index >= sizeof(message_t)) { if (recv_index >= ((message_t*) recv_buf)->length + sizeof(message_t)) { recv_index = 0; if (client_dispatch((message_t*) recv_buf)) client_disconnect(); } } return 0; } int client_work_accept() { int fd; if ((fd = accept(unix_socket, NULL, NULL)) < 0) return -1; if (client_socket >= 0) { daemon_log(LOG_WARNING, "Client connecting while already in use, closing"); close(fd); return 0; } client_socket = fd; FD_SET(client_socket, &listen_rfds); daemon_log(LOG_INFO, "Client connected"); return 0; } int client_work() { int r = 0; if (FD_ISSET(unix_socket, &select_rfds)) if ((r = client_work_accept()) < 0) return r; if (client_socket >= 0 && FD_ISSET(client_socket, &select_rfds)) if (client_work_recv() < 0) client_disconnect(); if (client_socket >= 0 && FD_ISSET(client_socket, &select_wfds)) if (client_work_send() < 0) client_disconnect(); return 0; } int client_send_enqueue(message_t *m) { g_assert(client_socket >= 0); if (send_queue_length+1 > MAX_QUEUE_LENGTH) return -1; send_queue = g_slist_append(send_queue, m); FD_SET(client_socket, &listen_wfds); send_queue_length++; return 0; } int client_is_connected() { return client_socket >= 0; } message_t* message_new(message_code_t c, guint8* d, guint s) { guchar *p; message_t *m; if (!d) s = 0; m = (message_t*) (p = g_new(guint8, sizeof(message_t) + s)); if (d) memcpy(p + sizeof(message_t), d, s); m->code = c; m->pid = getpid(); m->length = s; return m; }