summaryrefslogtreecommitdiffstats
path: root/daemon/client.c
diff options
context:
space:
mode:
Diffstat (limited to 'daemon/client.c')
-rw-r--r--daemon/client.c292
1 files changed, 292 insertions, 0 deletions
diff --git a/daemon/client.c b/daemon/client.c
new file mode 100644
index 0000000..d4e6c3b
--- /dev/null
+++ b/daemon/client.c
@@ -0,0 +1,292 @@
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+
+#include <daemon-log.h>
+
+#include "client.h"
+#include "main.h"
+#include "packet.h"
+#include "ipqapi.h"
+
+#define MAX_QUEUE_LENGTH 10
+#define BUFSIZE (10*1024)
+
+static int unix_socket = -1, client_socket = -1;
+
+static GSList *send_queue = NULL;
+static guint send_queue_length = 0;
+static guint send_index = 0;
+
+static guint8 recv_buf[BUFSIZE];
+static guint recv_index = 0;
+
+int client_init() {
+ struct sockaddr_un addr;
+ mode_t save;
+
+ save = umask(0007);
+
+ if ((unix_socket = socket(PF_LOCAL, SOCK_STREAM, 0)) < 0) {
+ umask(save);
+ daemon_log(LOG_ERR, "socket(PF_LOCAL, SOCK_STREAM, 0): %s", strerror(errno));
+ return -1;
+ }
+
+ addr.sun_family = AF_LOCAL;
+ strncpy(addr.sun_path, SOCKET_PATH, sizeof(addr.sun_path));
+ addr.sun_path[sizeof(addr.sun_path)-1] = 0;
+
+ if (bind(unix_socket, (struct sockaddr *) &addr, SUN_LEN(&addr)) < 0) {
+ close(unix_socket);
+ umask(save);
+ daemon_log(LOG_ERR, "bind(): %s", strerror(errno));
+ return (unix_socket = -1);
+ }
+
+ umask(save);
+
+ if (listen(unix_socket, 1) < 0) {
+ close(unix_socket);
+ daemon_log(LOG_ERR, "listen(): %s", strerror(errno));
+ return (unix_socket = -1);
+ }
+
+ FD_SET(unix_socket, &listen_rfds);
+
+ return 0;
+}
+
+void client_disconnect() {
+ if (client_socket < 0)
+ return;
+
+ FD_CLR(client_socket, &listen_rfds);
+ close(client_socket);
+ client_socket = -1;
+
+ while (send_queue) {
+ message_t *m = (message_t*) send_queue->data;
+ send_queue = g_slist_remove(send_queue, m);
+ g_free(m);
+ }
+
+ send_queue_length = 0;
+ send_index = 0;
+
+ recv_index = 0;
+
+ daemon_log(LOG_INFO, "Client disconnected");
+}
+
+
+void client_done() {
+ if (client_socket >= 0)
+ client_disconnect();
+
+ if (unix_socket >= 0) {
+ FD_CLR(unix_socket, &listen_rfds);
+ close(unix_socket);
+ unlink(SOCKET_PATH);
+ }
+
+ unix_socket = -1;
+}
+
+
+int client_work_send() {
+ g_assert(client_socket >= 0);
+
+ if (send_queue) {
+ size_t l;
+ ssize_t r;
+ message_t *m = (message_t*) send_queue->data;
+
+ g_assert(m);
+ l = m->length + sizeof(message_t);
+
+ if ((r = write(client_socket, ((guint8*) m) + send_index, l-send_index)) <= 0) {
+ daemon_log(LOG_ERR, "Write error on client socket (%s)", strerror(errno));
+ return -1;
+ }
+
+ send_index += r;
+
+ if (send_index >= l) {
+ send_queue = g_slist_remove(send_queue, m);
+ send_queue_length--;
+ send_index = 0;
+
+ g_free(m);
+ }
+ }
+
+ if (!send_queue)
+ FD_CLR(client_socket, &listen_wfds);
+ else
+ FD_SET(client_socket, &listen_wfds);
+
+ return 0;
+}
+
+int client_dispatch(message_t*m) {
+ switch (m->code) {
+
+ case MSG_SET_DEFAULT_VERDICT:
+
+ if (m->length != sizeof(guint32)) {
+ daemon_log(LOG_WARNING, "Client sent MSG_SET_DEFAULT_VERDICT message with bogus size.");
+ return -1;
+ }
+
+ default_verdict = *((guint*) (m+1));
+ return 0;
+
+ case MSG_VERDICT: {
+ ipq_packet_msg_t *ipqm;
+ unsigned long packet_id;
+ guint32 verdict;
+
+ if (m->length != sizeof(unsigned long)+sizeof(guint32)) {
+ daemon_log(LOG_WARNING, "Client sent MSG_VERDICT message with bogus size.");
+ return -1;
+ }
+
+ packet_id = *((unsigned long*) (m+1));
+ verdict = *((guint32*) (((guint8*) (m+1)) + sizeof(unsigned long)));
+
+ if (log_packets)
+ daemon_log(LOG_DEBUG, "[%lu] Recieved client verdict %u", packet_id, verdict);
+
+ if ((ipqm = packet_find(packet_id))) {
+ if (ipqapi_verdict(ipqm, verdict) < 0) {
+ daemon_log(LOG_ERR, "Could not verdict.");
+ fail = TRUE;
+ }
+
+ packet_release(packet_id);
+
+ } else
+ daemon_log(LOG_WARNING, "Recieved verdict for unknown packet id, ignoring");
+
+ return 0;
+ }
+
+ default:
+ daemon_log(LOG_WARNING, "Recieved bogus message from client.");
+ return -1;
+ }
+}
+
+int client_work_recv() {
+ ssize_t r;
+ size_t l;
+
+
+ if (recv_index >= sizeof(message_t)) {
+ l = ((message_t*) recv_buf)->length + sizeof(message_t);
+
+ if (l > BUFSIZE) {
+ daemon_log(LOG_WARNING, "Client message too large");
+ return -1;
+ }
+ } else
+ l = sizeof(message_t);
+
+ if ((r = read(client_socket, recv_buf + recv_index, l - recv_index)) <= 0) {
+ if (r < 0)
+ daemon_log(LOG_WARNING, "Read error on client socket (%s)", strerror(errno));
+ return -1;
+ }
+
+ recv_index += r;
+
+ if (recv_index >= sizeof(message_t)) {
+ if (recv_index >= ((message_t*) recv_buf)->length + sizeof(message_t)) {
+ recv_index = 0;
+
+ if (client_dispatch((message_t*) recv_buf))
+ client_disconnect();
+ }
+ }
+ return 0;
+}
+
+
+int client_work_accept() {
+ int fd;
+
+ if ((fd = accept(unix_socket, NULL, NULL)) < 0)
+ return -1;
+
+ if (client_socket >= 0) {
+ daemon_log(LOG_WARNING, "Client connecting while already in use, closing");
+ close(fd);
+ return 0;
+ }
+
+ client_socket = fd;
+ FD_SET(client_socket, &listen_rfds);
+
+ daemon_log(LOG_INFO, "Client connected");
+
+ return 0;
+}
+
+int client_work() {
+ int r = 0;
+
+ if (FD_ISSET(unix_socket, &select_rfds))
+ if ((r = client_work_accept()) < 0)
+ return r;
+
+ if (client_socket >= 0 && FD_ISSET(client_socket, &select_rfds))
+ if (client_work_recv() < 0)
+ client_disconnect();
+
+ if (client_socket >= 0 && FD_ISSET(client_socket, &select_wfds))
+ if (client_work_send() < 0)
+ client_disconnect();
+
+ return 0;
+}
+
+int client_send_enqueue(message_t *m) {
+ g_assert(client_socket >= 0);
+
+ if (send_queue_length+1 > MAX_QUEUE_LENGTH)
+ return -1;
+
+ send_queue = g_slist_append(send_queue, m);
+ FD_SET(client_socket, &listen_wfds);
+
+ send_queue_length++;
+ return 0;
+}
+
+int client_is_connected() {
+ return client_socket >= 0;
+}
+
+message_t* message_new(message_code_t c, guint8* d, guint s) {
+ guchar *p;
+ message_t *m;
+
+ if (!d)
+ s = 0;
+
+ m = (message_t*) (p = g_new(guint8, sizeof(message_t) + s));
+
+ if (d)
+ memcpy(p + sizeof(message_t), d, s);
+
+ m->code = c;
+ m->pid = getpid();
+ m->length = s;
+
+ return m;
+}
+
+