summaryrefslogtreecommitdiffstats
path: root/src/query.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/query.c')
-rw-r--r--src/query.c154
1 files changed, 96 insertions, 58 deletions
diff --git a/src/query.c b/src/query.c
index 546d072..1b66df9 100644
--- a/src/query.c
+++ b/src/query.c
@@ -36,15 +36,33 @@
#include <sys/time.h>
#include <net/if.h>
#include <sys/ioctl.h>
+#include <stdlib.h>
#include "dns.h"
#include "util.h"
#include "query.h"
-static const usec_t retry_ms[] = { 200000, 500000, 900000, 1400000, 0 };
+static const usec_t retry_ms[] = { 500000, 1000000, 0 };
+
+static uint16_t get_random_id(void) {
+ uint16_t id = 0;
+ int ok = 0, fd;
+
+ if ((fd = open("/dev/urandom", O_RDONLY)) >= 0) {
+ ok = read(fd, &id, sizeof(id)) == 2;
+ close(fd);
+ }
+
+ if (!ok)
+ ok = random() & 0xFFFF;
+
+ return id;
+}
static void mdns_mcast_group(struct sockaddr_in *ret_sa) {
assert(ret_sa);
+
+ memset(ret_sa, 0, sizeof(struct sockaddr_in));
ret_sa->sin_family = AF_INET;
ret_sa->sin_port = htons(5353);
@@ -52,44 +70,30 @@ static void mdns_mcast_group(struct sockaddr_in *ret_sa) {
}
int mdns_open_socket(void) {
- struct ip_mreqn mreq;
struct sockaddr_in sa;
int fd = -1, ttl, yes;
- mdns_mcast_group(&sa);
-
if ((fd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) {
fprintf(stderr, "socket() failed: %s\n", strerror(errno));
goto fail;
}
-
+
ttl = 255;
if (setsockopt(fd, IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl)) < 0) {
fprintf(stderr, "IP_MULTICAST_TTL failed: %s\n", strerror(errno));
goto fail;
}
- yes = 1;
- if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) < 0) {
- fprintf(stderr, "SO_REUSEADDR failed: %s\n", strerror(errno));
- goto fail;
- }
-
+ sa.sin_family = AF_INET;
+ sa.sin_port = 0;
+ sa.sin_addr.s_addr = INADDR_ANY;
+
if (bind(fd, (struct sockaddr*) &sa, sizeof(sa)) < 0) {
fprintf(stderr, "bind() failed: %s\n", strerror(errno));
goto fail;
}
-
- memset(&mreq, 0, sizeof(mreq));
- mreq.imr_multiaddr = sa.sin_addr;
- mreq.imr_address.s_addr = htonl(INADDR_ANY);
- mreq.imr_ifindex = 0;
- if (setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq)) < 0) {
- fprintf(stderr, "IP_ADD_MEMBERSHIP failed: %s\n", strerror(errno));
- goto fail;
- }
-
+ yes = 1;
if (setsockopt(fd, IPPROTO_IP, IP_RECVTTL, &yes, sizeof(yes)) < 0) {
fprintf(stderr, "IP_RECVTTL failed: %s\n", strerror(errno));
goto fail;
@@ -284,7 +288,7 @@ fail:
return ret;
}
-static int send_name_query(int fd, const char *name, int query_ipv4, int query_ipv6) {
+static int send_name_query(int fd, const char *name, uint16_t id, int query_ipv4, int query_ipv6) {
int ret = -1;
struct dns_packet *p = NULL;
uint8_t *prev_name = NULL;
@@ -297,6 +301,7 @@ static int send_name_query(int fd, const char *name, int query_ipv4, int query_i
goto finish;
}
+ dns_packet_set_field(p, DNS_FIELD_ID, id);
dns_packet_set_field(p, DNS_FIELD_FLAGS, DNS_FLAGS(0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
#ifndef NSS_IP6_ONLY
@@ -354,7 +359,24 @@ static int domain_cmp(const char *a, const char *b) {
return strncasecmp(a, b, al);
}
-static int process_name_response(int fd, const char *name, usec_t timeout, void (*ipv4_func)(const ipv4_address_t *ipv4, void *userdata), void (*ipv6_func)(const ipv6_address_t *ipv6, void *userdata), void *userdata) {
+static int skip_questions(struct dns_packet *p) {
+ unsigned i;
+ assert(p);
+
+ for (i = dns_packet_get_field(p, DNS_FIELD_QDCOUNT); i > 0; i--) {
+ char pname[256];
+ uint16_t type, class;
+
+ if (dns_packet_consume_name(p, pname, sizeof(pname)) < 0 ||
+ dns_packet_consume_uint16(p, &type) < 0 ||
+ dns_packet_consume_uint16(p, &class) < 0)
+ return -1;
+ }
+
+ return 0;
+}
+
+static int process_name_response(int fd, const char *name, usec_t timeout, uint16_t id, void (*ipv4_func)(const ipv4_address_t *ipv4, void *userdata), void (*ipv6_func)(const ipv6_address_t *ipv6, void *userdata), void *userdata) {
struct dns_packet *p = NULL;
int done = 0;
struct timeval end;
@@ -374,12 +396,15 @@ static int process_name_response(int fd, const char *name, usec_t timeout, void
return 1;
/* Ignore packets with RFC != 255 */
- if (ttl == 255) {
+ if (/* ttl == 255 && */
+ dns_packet_check_valid_response(p) >= 0 &&
+ dns_packet_get_field(p, DNS_FIELD_ID) == id) {
+
+ unsigned i;
- /* Ignore corrupt packets */
- if (dns_packet_check_valid_response(p) >= 0) {
+ if (skip_questions(p) >= 0)
- for (;;) {
+ for (i = dns_packet_get_field(p, DNS_FIELD_ANCOUNT); i > 0; i--) {
char pname[256];
uint16_t type, class;
uint32_t rr_ttl;
@@ -392,7 +417,7 @@ static int process_name_response(int fd, const char *name, usec_t timeout, void
dns_packet_consume_uint16(p, &rdlength) < 0) {
break;
}
-
+
/* Remove mDNS cache flush bit */
class &= ~0x8000;
@@ -402,27 +427,27 @@ static int process_name_response(int fd, const char *name, usec_t timeout, void
class == DNS_CLASS_IN &&
!domain_cmp(name, pname) &&
rdlength == sizeof(ipv4_address_t)) {
-
+
ipv4_address_t ipv4;
if (dns_packet_consume_bytes(p, &ipv4, sizeof(ipv4)) < 0)
break;
-
+
ipv4_func(&ipv4, userdata);
done = 1;
}
#endif
#if ! defined(NSS_IPV6_ONLY) && ! defined(NSS_IPV4_ONLY)
- else
+/* else */
#endif
#ifndef NSS_IPV4_ONLY
- if (ipv6_func &&
- type == DNS_TYPE_AAAA &&
- class == DNS_CLASS_IN &&
- !domain_cmp(name, pname) &&
- rdlength == sizeof(ipv6_address_t)) {
-
+ if (ipv6_func &&
+ type == DNS_TYPE_AAAA &&
+ class == DNS_CLASS_IN &&
+ !domain_cmp(name, pname) &&
+ rdlength == sizeof(ipv6_address_t)) {
+
ipv6_address_t ipv6;
if (dns_packet_consume_bytes(p, &ipv6, sizeof(ipv6_address_t)) < 0)
@@ -432,35 +457,40 @@ static int process_name_response(int fd, const char *name, usec_t timeout, void
done = 1;
}
#endif
- else {
-
+ else {
+
/* Step over */
if (dns_packet_consume_seek(p, rdlength) < 0)
break;
}
}
- }
}
+
if (p)
dns_packet_free(p);
- }
+ }
+
return 0;
}
int mdns_query_name(int fd, const char *name, void (*ipv4_func)(const ipv4_address_t *ipv4, void *userdata), void (*ipv6_func)(const ipv6_address_t *ipv6, void *userdata), void *userdata) {
const usec_t *timeout = retry_ms;
+ uint16_t id;
+
assert(fd >= 0 && name && (ipv4_func || ipv6_func));
+ id = get_random_id();
+
while (*timeout > 0) {
int n;
- if (send_name_query(fd, name, !!ipv4_func, !!ipv6_func) < 0)
+ if (send_name_query(fd, name, id, !!ipv4_func, !!ipv6_func) < 0)
return -1;
- if ((n = process_name_response(fd, name, *timeout, ipv4_func, ipv6_func, userdata)) < 0)
+ if ((n = process_name_response(fd, name, *timeout, id, ipv4_func, ipv6_func, userdata)) < 0)
return -1;
if (n == 0)
@@ -474,7 +504,7 @@ int mdns_query_name(int fd, const char *name, void (*ipv4_func)(const ipv4_addre
return -1;
}
-static int send_reverse_query(int fd, const char *name) {
+static int send_reverse_query(int fd, const char *name, uint16_t id) {
int ret = -1;
struct dns_packet *p = NULL;
@@ -485,6 +515,7 @@ static int send_reverse_query(int fd, const char *name) {
goto finish;
}
+ dns_packet_set_field(p, DNS_FIELD_ID, id);
dns_packet_set_field(p, DNS_FIELD_FLAGS, DNS_FLAGS(0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
if (!dns_packet_append_name(p, name)) {
@@ -506,7 +537,7 @@ finish:
return ret;
}
-static int process_reverse_response(int fd, const char *name, usec_t timeout, void (*name_func)(const char *name, void *userdata), void *userdata) {
+static int process_reverse_response(int fd, const char *name, usec_t timeout, uint16_t id, void (*name_func)(const char *name, void *userdata), void *userdata) {
struct dns_packet *p = NULL;
int done = 0;
struct timeval end;
@@ -526,12 +557,15 @@ static int process_reverse_response(int fd, const char *name, usec_t timeout, vo
return 1;
/* Ignore packets with RFC != 255 */
- if (ttl == 255) {
+ if (/* ttl == 255 && */
+ dns_packet_check_valid_response(p) >= 0 &&
+ dns_packet_get_field(p, DNS_FIELD_ID) == id) {
- /* Ignore corrupt packets */
- if (dns_packet_check_valid_response(p) >= 0) {
-
- for (;;) {
+ unsigned i;
+
+ if (skip_questions(p) >= 0) {
+
+ for (i = dns_packet_get_field(p, DNS_FIELD_ANCOUNT); i > 0; i--) {
char pname[256];
uint16_t type, class;
uint32_t rr_ttl;
@@ -544,24 +578,24 @@ static int process_reverse_response(int fd, const char *name, usec_t timeout, vo
dns_packet_consume_uint16(p, &rdlength) < 0) {
break;
}
-
+
/* Remove mDNS cache flush bit */
class &= ~0x8000;
if (type == DNS_TYPE_PTR &&
class == DNS_CLASS_IN &&
!domain_cmp(name, pname)) {
-
+
char rname[256];
-
+
if (dns_packet_consume_name(p, rname, sizeof(rname)) < 0)
break;
-
+
name_func(rname, userdata);
done = 1;
} else {
-
+
/* Step over */
if (dns_packet_consume_seek(p, rdlength) < 0)
@@ -580,15 +614,19 @@ static int process_reverse_response(int fd, const char *name, usec_t timeout, vo
static int query_reverse(int fd, const char *name, void (*name_func)(const char *name, void *userdata), void *userdata) {
const usec_t *timeout = retry_ms;
+ uint16_t id;
+
assert(fd >= 0 && name && name_func);
+ id = get_random_id();
+
while (*timeout > 0) {
int n;
- if (send_reverse_query(fd, name) <= 0) /* error or no interface to send data on */
+ if (send_reverse_query(fd, name, id) <= 0) /* error or no interface to send data on */
return -1;
- if ((n = process_reverse_response(fd, name, *timeout, name_func, userdata)) < 0)
+ if ((n = process_reverse_response(fd, name, *timeout, id, name_func, userdata)) < 0)
return -1;
if (n == 0)
@@ -616,7 +654,7 @@ int mdns_query_ipv4(int fd, const ipv4_address_t *ipv4, void (*name_func)(const
#endif
#ifndef NSS_IPV4_ONLY
-static int mdns_query_ipv6(int fd, const ipv6_address_t *ipv6, void (*name_func)(const char *name, void *userdata), void *userdata) {
+int mdns_query_ipv6(int fd, const ipv6_address_t *ipv6, void (*name_func)(const char *name, void *userdata), void *userdata) {
char name[256];
assert(fd >= 0 && ipv6 && name_func);