/*
 * xdp-server.c -- integration of AF_XDP into nsd
 *
 * Copyright (c) 2024, NLnet Labs. All rights reserved.
 *
 * See LICENSE for the license.
 *
 */

/*
 * Parts inspired by https://github.com/xdp-project/xdp-tutorial
 */

#include "config.h"

#ifdef USE_XDP

#include <assert.h>
#include <errno.h>
#include <netinet/in.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <unistd.h>
#include <linux/limits.h>
#include <sys/mman.h>

#include <sys/poll.h>
#include <sys/resource.h>

/* #include <bpf/bpf.h> */
#include <xdp/xsk.h>
#include <xdp/libxdp.h>
#include <bpf/libbpf.h>

#include <arpa/inet.h>
#include <linux/icmpv6.h>
#include <linux/if_ether.h>
#include <linux/ipv6.h>
#include <linux/ip.h>
#include <linux/udp.h>
#include <net/if.h>

#include <arpa/inet.h>
#include <netdb.h>
#include <ifaddrs.h>
#include <linux/if_link.h>

#include "query.h"
#include "dns.h"
#include "util.h"
#include "xdp-server.h"
#include "xdp-util.h"
#include "nsd.h"

// TODO: make configurable
#define DNS_PORT 53

/*
 * Flag set used while creating an AF_XDP socket.
 *
 * xsk_configure_socket() fills one of these and copies the members into the
 * corresponding fields of the struct xsk_socket_config handed to
 * xsk_socket__create().
 */
struct xdp_config {
	__u32 xdp_flags;      /* copied to xsk_socket_config.xdp_flags */
	__u32 libxdp_flags;   /* copied to xsk_socket_config.libxdp_flags */
	__u16 xsk_bind_flags; /* copied to xsk_socket_config.bind_flags */
};

/*
 * Address and length of one UMEM frame that holds a finished answer and is
 * waiting to be placed on the TX ring.
 */
struct umem_ptr {
	uint64_t addr; /* frame address as taken from the RX descriptor */
	uint32_t len;  /* length of the answer frame as set by process_packet() */
};

/*
 * Scratch list of frames to transmit for the batch currently being handled
 * by xdp_handle_recv_and_send(). Entries are reset to
 * XDP_INVALID_UMEM_FRAME / 0 once they have been consumed.
 */
static struct umem_ptr umem_ptrs[XDP_RX_BATCH_SIZE];

/*
 * Allocate memory for UMEM and setup rings
 */
static int
xsk_configure_umem(struct xsk_umem_info *umem_info, uint64_t size);

/*
 * Retrieve a UMEM frame address for allocation
 *
 * Returns XDP_INVALID_UMEM_FRAME when there are no free frames available.
 */
static uint64_t xsk_alloc_umem_frame(struct xsk_socket_info *xsk);

/*
 * Bind AF_XDP socket and setup rings
 */
static int xsk_configure_socket(struct xdp_server *xdp,
                                struct xsk_socket_info *xsk_info,
                                struct xsk_umem_info *umem,
                                uint32_t queue_index);

/*
 * Get number of free frames in UMEM
 */
static uint64_t xsk_umem_free_frames(struct xsk_socket_info *xsk);

/*
 * Free a frame in UMEM
 */
static void xsk_free_umem_frame(struct xsk_socket_info *xsk, uint64_t frame);

/*
 * Populate the fill ring with as many free UMEM frames as possible
 */
static void fill_fq(struct xsk_socket_info *xsk);

/*
 * Load eBPF program to forward traffic to our socket
 */
static int load_xdp_program_and_map(struct xdp_server *xdp);

/*
 * Unload eBPF/XDP program
 */
static void unload_xdp_program(struct xdp_server *xdp);

/*
 * Figure out IP addresses to listen to.
 */
static int figure_ip_addresses(struct xdp_server *xdp);

/*
 * Add IP address to allowed destination addresses for incoming packets
 */
static void add_ip_address(struct xdp_server *xdp,
                           struct sockaddr_storage *addr);

/*
 * Check whether destination IPv4 is in allowed IPs list
 */
static int dest_ip_allowed4(struct xdp_server *xdp, struct iphdr *ipv4);

/*
 * Check whether destination IPv6 is in allowed IPs list
 */
static int dest_ip_allowed6(struct xdp_server *xdp, struct ipv6hdr *ipv6);

/*
 * Setup XDP sockets
 */
static int xdp_sockets_init(struct xdp_server *xdp);

/*
 * Cleanup XDP sockets and memory
 */
static void xdp_sockets_cleanup(struct xdp_server *xdp);

/*
 * Allocate a block of shared memory
 */
static void *alloc_shared_mem(size_t len);

/*
 * Collect free frames from completion queue
 */
static void drain_cq(struct xsk_socket_info *xsk);

/*
 * Send outstanding packets and recollect completed frame addresses
 */
static void handle_tx(struct xsk_socket_info *xsk);

/*
 * Process packet and indicate if it should be dropped
 * return 0 or less => drop
 * return greater than 0 => use for tx
 */
static int
process_packet(struct xdp_server *xdp,
               uint8_t *pkt,
               uint32_t *len,
               struct query *query);

static inline void swap_eth(struct ethhdr *eth);
static inline void swap_udp(struct udphdr *udp);
static inline void swap_ipv6(struct ipv6hdr *ipv6);
static inline void swap_ipv4(struct iphdr *ipv4);
static inline void *parse_udp(struct udphdr *udp);
static inline void *parse_ipv6(struct ipv6hdr *ipv6);
static inline void *parse_ipv4(struct iphdr *ipv4);

/*
 * Parse dns message and return new length of dns message
 */
static uint32_t parse_dns(struct nsd* nsd,
                          uint32_t dnslen,
                          struct query *q,
                          sa_family_t ai_family);

/* *************** */
/* Implementations */
/* *************** */

/*
 * Pop one frame address off the stack of free UMEM frames.
 *
 * The vacated stack slot is overwritten with XDP_INVALID_UMEM_FRAME.
 * Returns XDP_INVALID_UMEM_FRAME when no free frame is left.
 */
static uint64_t xsk_alloc_umem_frame(struct xsk_socket_info *xsk) {
	uint64_t taken;

	if (xsk->umem->umem_frame_free == 0)
		return XDP_INVALID_UMEM_FRAME;

	/* the top of the stack is the last free entry */
	xsk->umem->umem_frame_free -= 1;
	taken = xsk->umem->umem_frame_addr[xsk->umem->umem_frame_free];
	xsk->umem->umem_frame_addr[xsk->umem->umem_frame_free] =
		XDP_INVALID_UMEM_FRAME;

	return taken;
}

/*
 * Report how many UMEM frames are currently on the free stack.
 */
static uint64_t xsk_umem_free_frames(struct xsk_socket_info *xsk) {
	uint64_t available = xsk->umem->umem_frame_free;

	return available;
}

/*
 * Push a frame address back onto the stack of free UMEM frames.
 *
 * Asserts that the stack is not already full (XDP_NUM_FRAMES entries).
 */
static void xsk_free_umem_frame(struct xsk_socket_info *xsk, uint64_t frame) {
	assert(xsk->umem->umem_frame_free < XDP_NUM_FRAMES);

	xsk->umem->umem_frame_addr[xsk->umem->umem_frame_free] = frame;
	xsk->umem->umem_frame_free += 1;
}

/*
 * Hand free UMEM frames to the kernel by placing them on the fill ring.
 *
 * The number of frames submitted is the smaller of the free slots in the
 * fill ring and the number of free UMEM frames, so an invalid frame address
 * is never written into the ring.
 */
static void fill_fq(struct xsk_socket_info *xsk) {
	uint32_t idx_fq = 0;
	uint32_t free_frames = (uint32_t) xsk_umem_free_frames(xsk);
	uint32_t stock_frames;

	/* xsk_prod_nb_free() treats its second argument only as a hint for when
	 * to refresh its cached consumer pointer; the value it returns is the
	 * number of free ring slots and may be larger than the hint. Clamp it to
	 * the number of frames we can actually hand out. */
	stock_frames = xsk_prod_nb_free(&xsk->umem->fq, free_frames);
	if (stock_frames > free_frames)
		stock_frames = free_frames;

	if (stock_frames == 0)
		return;

	/* reserve either grants all requested slots or none (returns 0); since
	 * the amount was just reported as free this is expected to succeed, but
	 * do not write into slots that were not granted */
	if (xsk_ring_prod__reserve(&xsk->umem->fq, stock_frames, &idx_fq)
	    != stock_frames)
		return;

	for (uint32_t i = 0; i < stock_frames; ++i) {
		*xsk_ring_prod__fill_addr(&xsk->umem->fq, idx_fq++) =
			xsk_alloc_umem_frame(xsk);
	}

	xsk_ring_prod__submit(&xsk->umem->fq, stock_frames);
}

/*
 * Open the XDP program object and obtain a file descriptor for its
 * "xsks_map".
 *
 * When xdp->bpf_prog_should_load is set, the program is attached to the
 * configured interface and the map descriptor is taken from the loaded
 * object. Otherwise the map is expected to be pinned at
 * "<bpf_bpffs_path>/xsks_map" and that pin is re-used.
 *
 * On success xdp->xsk_map and xdp->xsk_map_fd are set (plus bpf_prog_fd and
 * bpf_prog_id when the program was attached here) and 0 is returned.
 * On failure a non-zero error code is returned and the problem is logged.
 */
static int load_xdp_program_and_map(struct xdp_server *xdp) {
	struct bpf_map *map;
	char errmsg[512];
	int err, ret;
	/* UNSPEC => let libxdp decide */
	// TODO: put this into a config option as well?
	enum xdp_attach_mode attach_mode = XDP_MODE_UNSPEC;

	DECLARE_LIBXDP_OPTS(bpf_object_open_opts, opts);
	if (xdp->bpf_bpffs_path)
		opts.pin_root_path = xdp->bpf_bpffs_path;

	/* for now our xdp program should contain just one program section */
	// TODO: look at xdp_program__create because it can take a pinned prog
	xdp->bpf_prog = xdp_program__open_file(xdp->bpf_prog_filename, NULL, &opts);

	// conversion should be fine, libxdp errors shouldn't exceed (int),
	// also libxdp_strerr takes int anyway...
	err = (int) libxdp_get_error(xdp->bpf_prog);
	if (err) {
		libxdp_strerror(err, errmsg, sizeof(errmsg));
		log_msg(LOG_ERR, "xdp: could not open xdp program: %s\n", errmsg);
		return err;
	}

	if (xdp->bpf_prog_should_load) {
		/* TODO: I find setting environment variables from within a program
		 * not a good thing to do, but for the meantime this helps... */
		/* This is done to allow unloading the XDP program we load without
		 * needing the SYS_ADMIN capability, and libxdp doesn't allow skipping
		 * the dispatcher through different means. */
		putenv("LIBXDP_SKIP_DISPATCHER=1");
		err = xdp_program__attach(xdp->bpf_prog, (int) xdp->interface_index, attach_mode, 0);
		if (err) {
			libxdp_strerror(err, errmsg, sizeof(errmsg));
			log_msg(LOG_ERR, "xdp: could not attach xdp program to interface '%s' : %s\n",
					xdp->interface_name, errmsg);
			return err;
		}

		xdp->bpf_prog_fd = xdp_program__fd(xdp->bpf_prog);
		xdp->bpf_prog_id = xdp_program__id(xdp->bpf_prog);

		/* We also need to get the file descriptor to the xsks_map */
		map = bpf_object__find_map_by_name(xdp_program__bpf_obj(xdp->bpf_prog), "xsks_map");
		ret = map ? bpf_map__fd(map) : -ENOENT;
		if (ret < 0) {
			/* ret is a negative errno value, strerror() wants it positive */
			log_msg(LOG_ERR, "xdp: no xsks map found in xdp program: %s\n", strerror(-ret));
			return ret;
		}
		xdp->xsk_map_fd = ret;
		xdp->xsk_map = map;
	} else {
		char map_path[PATH_MAX];
		int fd, path_len;

		/* without a bpffs path there is no pin to look up, and passing NULL
		 * to a "%s" conversion is undefined */
		if (!xdp->bpf_bpffs_path) {
			log_msg(LOG_ERR, "xdp: no bpffs path available to retrieve the pinned xsks_map from");
			return -EINVAL;
		}

		path_len = snprintf(map_path, sizeof(map_path), "%s/%s",
		                    xdp->bpf_bpffs_path, "xsks_map");
		if (path_len < 0 || (size_t) path_len >= sizeof(map_path)) {
			log_msg(LOG_ERR, "xdp: path to pinned xsks_map is too long");
			return -ENAMETOOLONG;
		}

		fd = bpf_obj_get(map_path);
		if (fd < 0) {
			log_msg(LOG_ERR, "xdp: could not retrieve xsks_map pin from %s: %s", map_path, strerror(errno));
			return fd;
		}

		map = bpf_object__find_map_by_name(xdp_program__bpf_obj(xdp->bpf_prog), "xsks_map");
		if (!map) {
			log_msg(LOG_ERR, "xdp: no xsks map found in xdp program: %s\n", strerror(ENOENT));
			close(fd);
			return -ENOENT;
		}

		if ((ret = bpf_map__reuse_fd(map, fd))) {
			log_msg(LOG_ERR, "xdp: could not re-use xsks_map: %s\n", strerror(errno));
			/* log first, close() may overwrite errno */
			close(fd);
			return ret;
		}

		xdp->xsk_map_fd = fd;
		xdp->xsk_map = map;
	}

	return 0;
}

/*
 * Register umem_info->buffer (size bytes) as UMEM and set up the fill and
 * completion rings stored in umem_info.
 *
 * Returns 0 on success. On failure returns the error code from
 * xsk_umem__create() and stores its negation in errno.
 */
static int
xsk_configure_umem(struct xsk_umem_info *umem_info, uint64_t size) {
	struct xsk_umem_config umem_config = {
		.flags = XSK_UMEM_FLAGS,
		.frame_headroom = XSK_UMEM_FRAME_HEADROOM,
		.frame_size = XDP_FRAME_SIZE,
		.comp_size = XSK_RING_CONS__NUM_DESCS,
		.fill_size = XSK_RING_PROD__NUM_DESCS,
	};
	int rc = xsk_umem__create(&umem_info->umem,
	                          umem_info->buffer,
	                          size,
	                          &umem_info->fq,
	                          &umem_info->cq,
	                          &umem_config);

	if (rc != 0)
		errno = -rc;

	return rc;
}

/*
 * Create the AF_XDP socket for queue_index on the configured interface,
 * register it in the xsks map, initialise the UMEM free-frame stack and
 * pre-populate the fill ring.
 *
 * Returns 0 on success. On failure a negative error code is returned, its
 * negation is stored in errno, and a socket that was already created is
 * deleted again (xsk_info->xsk is reset to NULL).
 */
static int
xsk_configure_socket(struct xdp_server *xdp, struct xsk_socket_info *xsk_info,
                     struct xsk_umem_info *umem, uint32_t queue_index) {
	struct xsk_socket_config xsk_cfg;
	uint32_t idx = 0, reserved;
	int ret;

	struct xdp_config cfg = {
		.xdp_flags = 0,
		.xsk_bind_flags = 0,
		.libxdp_flags = XSK_LIBXDP_FLAGS__INHIBIT_PROG_LOAD,
	};

	uint16_t xsk_bind_flags = XDP_USE_NEED_WAKEUP;
	if (xdp->force_copy) {
		xsk_bind_flags |= XDP_COPY;
	}
	cfg.xsk_bind_flags = xsk_bind_flags;

	/* start from an all-zero config so that no member is left
	 * uninitialised */
	memset(&xsk_cfg, 0, sizeof(xsk_cfg));

	xsk_info->umem = umem;
	xsk_cfg.rx_size = XSK_RING_CONS__NUM_DESCS;
	xsk_cfg.tx_size = XSK_RING_PROD__NUM_DESCS;
	xsk_cfg.xdp_flags = cfg.xdp_flags;
	xsk_cfg.bind_flags = cfg.xsk_bind_flags;
	xsk_cfg.libxdp_flags = cfg.libxdp_flags;

	ret = xsk_socket__create(&xsk_info->xsk,
	                         xdp->interface_name,
	                         queue_index,
	                         umem->umem,
	                         &xsk_info->rx,
	                         &xsk_info->tx,
	                         &xsk_cfg);
	if (ret) {
		log_msg(LOG_ERR, "xdp: failed to create xsk_socket");
		goto error_exit;
	}

	ret = xsk_socket__update_xskmap(xsk_info->xsk, xdp->xsk_map_fd);
	if (ret) {
		log_msg(LOG_ERR, "xdp: failed to update xskmap");
		goto error_delete;
	}

	/* Initialize umem frame allocation */
	for (uint32_t i = 0; i < XDP_NUM_FRAMES; ++i) {
		/* widen before multiplying so large UMEMs cannot wrap */
		xsk_info->umem->umem_frame_addr[i] = (uint64_t) i * XDP_FRAME_SIZE;
	}

	xsk_info->umem->umem_frame_free = XDP_NUM_FRAMES;

	reserved = xsk_ring_prod__reserve(&xsk_info->umem->fq,
	                             XSK_RING_PROD__NUM_DESCS,
	                             &idx);

	if (reserved != XSK_RING_PROD__NUM_DESCS) {
		log_msg(LOG_ERR,
		        "xdp: amount of reserved addr not as expected (is %u)", reserved);
		// ENOMEM ("Cannot allocate memory") is the closest to the
		// error that not as much memory was reserved as expected
		ret = -ENOMEM;
		goto error_delete;
	}

	for (uint32_t i = 0; i < XSK_RING_PROD__NUM_DESCS; ++i) {
		*xsk_ring_prod__fill_addr(&xsk_info->umem->fq, idx++) =
			xsk_alloc_umem_frame(xsk_info);
	}

	xsk_ring_prod__submit(&xsk_info->umem->fq, XSK_RING_PROD__NUM_DESCS);

	return 0;

error_delete:
	/* the socket exists at this point; release it so that it does not leak
	 * and does not keep the UMEM referenced */
	xsk_socket__delete(xsk_info->xsk);
	xsk_info->xsk = NULL;

error_exit:
	errno = -ret;
	return ret;
}

/*
 * Map len bytes of anonymous, shared, read/write memory.
 *
 * Returns the mapping, or MAP_FAILED (with errno set by mmap) on failure.
 */
static void *alloc_shared_mem(size_t len) {
	const int protection = PROT_READ | PROT_WRITE;
	const int visibility = MAP_SHARED | MAP_ANONYMOUS;
	void *mapping;

	/* MAP_ANONYMOUS memory is initialized with zero */
	mapping = mmap(NULL, len, protection, visibility, -1, 0);
	return mapping;
}

/*
 * Allocate the per-queue bookkeeping (xdp->umems, xdp->xsks) in shared
 * memory and create one UMEM and one AF_XDP socket per queue.
 *
 * Returns 0 on success and -1 on failure. On failure, every socket, UMEM and
 * UMEM buffer that had already been created is released again; the
 * bookkeeping arrays themselves stay mapped with their entries reset to
 * NULL.
 */
static int xdp_sockets_init(struct xdp_server *xdp) {
	size_t umems_len = sizeof(struct xsk_umem_info) * xdp->queue_count;
	size_t xsks_len = sizeof(struct xsk_socket_info) * xdp->queue_count;

	xdp->umems = (struct xsk_umem_info *) alloc_shared_mem(umems_len);
	if (xdp->umems == MAP_FAILED) {
		log_msg(LOG_ERR,
		        "xdp: failed to allocate shared memory for umem info: %s",
		        strerror(errno));
		return -1;
	}

	xdp->xsks = (struct xsk_socket_info *) alloc_shared_mem(xsks_len);
	if (xdp->xsks == MAP_FAILED) {
		log_msg(LOG_ERR,
		        "xdp: failed to allocate shared memory for xsk info: %s",
		        strerror(errno));
		return -1;
	}

	for (uint32_t q_idx = 0; q_idx < xdp->queue_count; ++q_idx) {
		/* mmap is supposedly page-aligned, so should be fine */
		void *buffer = alloc_shared_mem(XDP_BUFFER_SIZE);

		/* only store the buffer once we know the mapping succeeded, so the
		 * unwind below never sees MAP_FAILED */
		if (buffer == MAP_FAILED) {
			log_msg(LOG_ERR,
			        "xdp: failed to allocate shared memory for umem buffer: %s",
			        strerror(errno));
			goto out_err;
		}
		xdp->umems[q_idx].buffer = buffer;

		if (xsk_configure_umem(&xdp->umems[q_idx], XDP_BUFFER_SIZE)) {
			log_msg(LOG_ERR, "xdp: cannot create umem: %s", strerror(errno));
			goto out_err;
		}

		if (xsk_configure_socket(xdp, &xdp->xsks[q_idx], &xdp->umems[q_idx],
		                         q_idx)) {
			log_msg(LOG_ERR,
			        "xdp: cannot create AF_XDP socket: %s",
			        strerror(errno));
			goto out_err;
		}
	}

	return 0;

out_err:
	/* The bookkeeping arrays start out zeroed (anonymous mapping), so any
	 * entry that was never set up is still NULL and is skipped here.
	 * A socket has to be deleted before the UMEM it is bound to. */
	for (uint32_t i = 0; i < xdp->queue_count; ++i) {
		if (xdp->xsks[i].xsk) {
			xsk_socket__delete(xdp->xsks[i].xsk);
			xdp->xsks[i].xsk = NULL;
		}
		if (xdp->umems[i].umem) {
			xsk_umem__delete(xdp->umems[i].umem);
			xdp->umems[i].umem = NULL;
		}
		if (xdp->umems[i].buffer) {
			munmap(xdp->umems[i].buffer, XDP_BUFFER_SIZE);
			xdp->umems[i].buffer = NULL;
		}
	}

	return -1;
}

/*
 * Delete the AF_XDP socket and the UMEM of every queue, socket first.
 */
static void xdp_sockets_cleanup(struct xdp_server *xdp) {
	uint32_t queue = 0;

	while (queue < xdp->queue_count) {
		xsk_socket__delete(xdp->xsks[queue].xsk);
		xsk_umem__delete(xdp->umems[queue].umem);
		++queue;
	}
}

/*
 * Bring up the XDP server: resolve the interface, load/locate the XDP
 * program and its xsks map, lift RLIMIT_MEMLOCK, create the AF_XDP sockets
 * and, unless addresses are already configured, collect the interface's IP
 * addresses.
 *
 * Returns 0 on success and -1 on failure.
 */
int xdp_server_init(struct xdp_server *xdp) {
	struct rlimit unlimited = {RLIM_INFINITY, RLIM_INFINITY};
	int slot;

	/* check if interface name exists */
	xdp->interface_index = if_nametoindex(xdp->interface_name);
	if (xdp->interface_index == 0) {
		log_msg(LOG_ERR, "xdp: configured xdp-interface (%s) is unknown: %s",
		        xdp->interface_name, strerror(errno));
		return -1;
	}

	/* (optionally) load xdp program and (definitely) set xsks_map_fd */
	if (load_xdp_program_and_map(xdp) != 0) {
		log_msg(LOG_ERR, "xdp: failed to load/pin xdp program/map");
		return -1;
	}

	/* libbpf would adjust the rlimit itself if we did not; doing it here
	 * has to happen before privileges are dropped, or it requires
	 * CAP_SYS_RESOURCE */
	if (setrlimit(RLIMIT_MEMLOCK, &unlimited) != 0) {
		log_msg(LOG_ERR, "xdp: cannot adjust rlimit (RLIMIT_MEMLOCK): \"%s\"\n",
			strerror(errno));
		return -1;
	}

	if (xdp_sockets_init(xdp) != 0)
		return -1;

	/* mark every slot of the transmit scratch list as unused */
	for (slot = 0; slot < XDP_RX_BATCH_SIZE; slot++) {
		umem_ptrs[slot].len = 0;
		umem_ptrs[slot].addr = XDP_INVALID_UMEM_FRAME;
	}

	/* the result is not checked: a failure is logged inside and leaves the
	 * address list empty */
	if (!xdp->ip_addresses)
		figure_ip_addresses(xdp);

	return 0;
}

/*
 * Tear down the AF_XDP sockets and, if this process loaded the XDP program,
 * unpin its xsks map (when pinned) and detach the program.
 */
void xdp_server_cleanup(struct xdp_server *xdp) {
	xdp_sockets_cleanup(xdp);

	/* only unpin/unload if we loaded the program */
	if (!xdp->bpf_prog_should_load)
		return;

	if (xdp->xsk_map
	    && bpf_map__is_pinned(xdp->xsk_map)
	    && bpf_map__unpin(xdp->xsk_map, NULL)) {
		/* The XDP program we currently ship does not pin the map. If
		 * unpinning fails, the user had NSD load their own XDP program,
		 * so they should know about the pinned map and be able to unlink
		 * it themselves. */
		log_msg(LOG_ERR, "xdp: failed to unpin bpf map during cleanup: \"%s\". "
		        "This is usually ok, but you need to unpin the map yourself. "
		        "This can usually be fixed by executing chmod o+wx %s\n",
		        strerror(errno), xdp->bpf_bpffs_path);
	}

	unload_xdp_program(xdp);
}

static void unload_xdp_program(struct xdp_server *xdp) {
	DECLARE_LIBBPF_OPTS(bpf_xdp_attach_opts, bpf_opts,
	                    .old_prog_fd = xdp->bpf_prog_fd);

	log_msg(LOG_INFO, "xdp: detaching xdp program %u from %s\n",
			xdp->bpf_prog_id, xdp->interface_name);

	if (bpf_xdp_detach((int) xdp->interface_index, 0, &bpf_opts))
		log_msg(LOG_ERR, "xdp: failed to detach xdp program: %s\n",
		        strerror(errno));
}

/*
 * Decide whether the destination address of an IPv6 header is served.
 *
 * Returns 1 if the allow-list is empty (everything is served) or contains
 * the destination address, 0 otherwise.
 */
static int dest_ip_allowed6(struct xdp_server *xdp, struct ipv6hdr *ipv6) {
	struct xdp_ip_address *entry;

	/* no IPs available, allowing all */
	if (xdp->ip_addresses == NULL)
		return 1;

	for (entry = xdp->ip_addresses; entry != NULL; entry = entry->next) {
		struct sockaddr_in6 *candidate;

		if (entry->addr.ss_family != AF_INET6)
			continue;

		candidate = (struct sockaddr_in6 *) &entry->addr;
		if (memcmp(&candidate->sin6_addr, &ipv6->daddr,
		           sizeof(struct in6_addr)) == 0)
			return 1;
	}

	return 0;
}

static int dest_ip_allowed4(struct xdp_server *xdp, struct iphdr *ipv4) {
	struct xdp_ip_address *ip = xdp->ip_addresses;
	if (!ip)
		// no IPs available, allowing all
		return 1;

	while (ip) {
		if (ip->addr.ss_family == AF_INET &&
		    ipv4->daddr == ((struct sockaddr_in *) &ip->addr)->sin_addr.s_addr)
			return 1;
		ip = ip->next;
	}

	return 0;
}

/*
 * Append a copy of addr to the list of allowed destination addresses.
 *
 * addr is declared as sockaddr_storage, but figure_ip_addresses() passes a
 * pointer that really refers to the smaller sockaddr_in / sockaddr_in6
 * object returned by getifaddrs(). Therefore only as many bytes as the
 * address family occupies are copied; the remainder of the zero-allocated
 * entry stays zero.
 */
static void
add_ip_address(struct xdp_server *xdp, struct sockaddr_storage *addr) {
	struct xdp_ip_address **tail = &xdp->ip_addresses;
	size_t addr_len;

	switch (addr->ss_family) {
	case AF_INET:
		addr_len = sizeof(struct sockaddr_in);
		break;
	case AF_INET6:
		addr_len = sizeof(struct sockaddr_in6);
		break;
	default:
		/* unknown family: the caller is expected to provide a full
		 * sockaddr_storage */
		addr_len = sizeof(struct sockaddr_storage);
		break;
	}

	/* walk to the end of the list; tail then points at the link to set */
	while (*tail)
		tail = &(*tail)->next;

	/* NOTE(review): the result of region_alloc_zero() is used unchecked, as
	 * before — presumably the region allocator does not return NULL;
	 * confirm. */
	*tail = region_alloc_zero(xdp->region, sizeof(struct xdp_ip_address));

	memcpy(&(*tail)->addr, addr, addr_len);
}

/*
 * Collect the IPv4 and IPv6 addresses configured on xdp->interface_name and
 * add them to the allowed destination addresses.
 *
 * Returns 0 on success and -1 if the local addresses cannot be queried (in
 * which case the list is left untouched).
 */
static int figure_ip_addresses(struct xdp_server *xdp) {
	// TODO: if using VLANs, also find appropriate IP addresses?
	struct ifaddrs *all_addrs;
	struct ifaddrs *cur;

	if (getifaddrs(&all_addrs) == -1) {
		log_msg(LOG_ERR, "xdp: couldn't determine local IP addresses. "
		                 "Serving all IP addresses now");
		return -1;
	}

	for (cur = all_addrs; cur != NULL; cur = cur->ifa_next) {
		int family;

		/* skip entries without an address or for other interfaces */
		if (cur->ifa_addr == NULL)
			continue;
		if (strcmp(cur->ifa_name, xdp->interface_name) != 0)
			continue;

		family = cur->ifa_addr->sa_family;
		if (family == AF_INET || family == AF_INET6)
			add_ip_address(xdp, (struct sockaddr_storage *) cur->ifa_addr);
	}

	freeifaddrs(all_addrs);
	return 0;
}

/*
 * Exchange source and destination MAC address of an ethernet header.
 */
static inline void swap_eth(struct ethhdr *eth) {
	for (int octet = 0; octet < ETH_ALEN; ++octet) {
		uint8_t saved = eth->h_dest[octet];

		eth->h_dest[octet] = eth->h_source[octet];
		eth->h_source[octet] = saved;
	}
}

/*
 * Exchange source and destination port of a UDP header.
 * The values are moved as-is; byte order is not touched.
 */
static inline void swap_udp(struct udphdr *udp) {
	const uint16_t old_source = udp->source;
	const uint16_t old_dest = udp->dest;

	udp->source = old_dest;
	udp->dest = old_source;
}

/*
 * Exchange source and destination address of an IPv6 header.
 */
static inline void swap_ipv6(struct ipv6hdr *ipv6) {
	struct in6_addr old_source = ipv6->saddr;

	ipv6->saddr = ipv6->daddr;
	ipv6->daddr = old_source;
}

/*
 * Exchange source and destination address of an IPv4 header.
 * The values are moved as-is; byte order is not touched.
 */
static inline void swap_ipv4(struct iphdr *ipv4) {
	const uint32_t old_source = ipv4->saddr;

	ipv4->saddr = ipv4->daddr;
	ipv4->daddr = old_source;
}

/*
 * Return a pointer to the data directly behind the UDP header, or NULL if
 * the destination port is not DNS_PORT.
 */
static inline void *parse_udp(struct udphdr *udp) {
	const int for_dns = (ntohs(udp->dest) == DNS_PORT);

	return for_dns ? (void *)(udp + 1) : NULL;
}

/*
 * Return a pointer to the data directly behind the fixed IPv6 header, or
 * NULL if the next header is not UDP.
 */
static inline void *parse_ipv6(struct ipv6hdr *ipv6) {
	const int carries_udp = (ipv6->nexthdr == IPPROTO_UDP);

	return carries_udp ? (void *)(ipv6 + 1) : NULL;
}

/*
 * Return a pointer to the UDP header behind an IPv4 header, or NULL if the
 * packet cannot be handled:
 *  - the protocol is not UDP,
 *  - the header carries options (IHL != 5): the payload would then not
 *    start directly behind the fixed 20 byte header, and the rest of the
 *    packet handling assumes a header of exactly sizeof(struct iphdr),
 *  - the packet is a fragment (MF flag set or non-zero fragment offset):
 *    a single fragment does not contain a complete UDP datagram.
 */
static inline void *parse_ipv4(struct iphdr *ipv4) {
	/* "more fragments" flag (0x2000) plus 13 bit fragment offset (0x1fff);
	 * the "don't fragment" flag (0x4000) is deliberately not part of it */
	const uint16_t fragment_bits = htons(0x3fff);

	if (ipv4->protocol != IPPROTO_UDP)
		return NULL;

	if (ipv4->ihl != sizeof(*ipv4) / 4)
		return NULL;

	if (ipv4->frag_off & fragment_bits)
		return NULL;

	return (void *)(ipv4 + 1);
}

/*
 * Run the DNS query found in q->packet through query_process() and leave
 * the answer in the same buffer.
 *
 * nsd:       server state, used for query processing and statistics
 * dnslen:    number of bytes of DNS message present in q->packet
 * q:         query whose packet buffer holds the message
 * ai_family: address family of the client; only consulted for the per-zone
 *            statistics when USE_ZONE_STATS is defined
 *
 * Returns the length of the answer, or 0 when query_process() reports
 * QUERY_DISCARDED (the query is reset and counted as dropped).
 */
static uint32_t parse_dns(struct nsd* nsd, uint32_t dnslen,
                          struct query *q, sa_family_t ai_family) {
	/* TODO: implement DNSTAP, PROXYv2, ...? */
	/* starts out as 0 and is handed by pointer to query_process() and
	 * query_add_optional() */
	uint32_t now = 0;

	/* set the size of the dns message and move position to start */
	buffer_skip(q->packet, dnslen);
	buffer_flip(q->packet);

	if (query_process(q, nsd, &now) != QUERY_DISCARDED) {
		/* count successful answers that are not authoritative */
		if (RCODE(q->packet) == RCODE_OK && !AA(q->packet)) {
			STATUP(nsd, nona);
			ZTATUP(nsd, q->zone, nona);
		}

#ifdef USE_ZONE_STATS
		if (ai_family == AF_INET) {
			ZTATUP(nsd, q->zone, qudp);
		} else if (ai_family == AF_INET6) {
			ZTATUP(nsd, q->zone, qudp6);
		}
#endif /* USE_ZONE_STATS */

		query_add_optional(q, nsd, &now);

		/* flip again so that the remaining bytes are the answer */
		buffer_flip(q->packet);

#ifdef BIND8_STATS
			/* Account the rcode & TC... */
			STATUP2(nsd, rcode, RCODE(q->packet));
			ZTATUP2(nsd, q->zone, rcode, RCODE(q->packet));
			if (TC(q->packet)) {
				STATUP(nsd, truncated);
				ZTATUP(nsd, q->zone, truncated);
			}
#endif /* BIND8_STATS */

		/* return new dns message length */
		return (uint32_t) buffer_remaining(q->packet);
	} else {
		/* discarded: reset the query and account the drop */
		query_reset(q, UDP_MAX_MESSAGE_LEN, 0);
		STATUP(nsd, dropped);
		ZTATUP(nsd, q->zone, dropped);
		return 0;
	}
}

/*
 * Turn one received frame into its DNS answer, in place.
 *
 * xdp:   server state (allowed addresses, nsd instance)
 * pkt:   start of the ethernet frame inside the UMEM
 * len:   in: length of the received frame, out: length of the answer frame
 *        (only updated when 1 is returned)
 * query: query object used for processing; its packet buffer is pointed at
 *        the DNS payload of the frame
 *
 * Returns 1 when the frame now contains an answer that should be sent, and
 * 0 or a negative value when the frame has to be dropped. The distinct
 * negative values only identify the reason for the drop.
 */
static int
process_packet(struct xdp_server *xdp, uint8_t *pkt,
               uint32_t *len, struct query *query) {
	/* log_msg(LOG_INFO, "xdp: received packet with len %d", *len); */

	uint32_t dnslen = *len;
	uint32_t data_before_dnshdr_len = 0;
	uint32_t udp_len = 0;

	struct ethhdr *eth = (struct ethhdr *)pkt;
	struct ipv6hdr *ipv6 = NULL;
	struct iphdr *ipv4 = NULL;
	struct udphdr *udp = NULL;
	void *dnshdr = NULL;

	/* doing the check here, so that the packet/frame is large enough to contain
	 * at least an ethernet header, an ipv4 header (ipv6 header is larger), and
	 * a udp header.
	 */
	if (*len < (sizeof(*eth) + sizeof(struct iphdr) + sizeof(*udp)))
		return -1;

	data_before_dnshdr_len = sizeof(*eth) + sizeof(*udp);

	switch (ntohs(eth->h_proto)) {
	case ETH_P_IPV6: {
		ipv6 = (struct ipv6hdr *)(eth + 1);

		if (*len < (sizeof(*eth) + sizeof(*ipv6) + sizeof(*udp)))
			return -2;
		if (!(udp = parse_ipv6(ipv6)))
			return -3;

		dnslen -= (uint32_t) (sizeof(*eth) + sizeof(*ipv6) + sizeof(*udp));
		data_before_dnshdr_len += sizeof(*ipv6);

		if (!dest_ip_allowed6(xdp, ipv6))
			return -4;

		break;
	} case ETH_P_IP: {
		ipv4 = (struct iphdr *)(eth + 1);

		if (!(udp = parse_ipv4(ipv4)))
			return -5;

		dnslen -= (uint32_t) (sizeof(*eth) + sizeof(*ipv4) + sizeof(*udp));
		data_before_dnshdr_len += sizeof(*ipv4);

		if (!dest_ip_allowed4(xdp, ipv4))
			return -6;

		break;
	}

	/* TODO: vlan? */
	/* case ETH_P_8021AD: case ETH_P_8021Q: */
	/*     if (*len < (sizeof(*eth) + sizeof(*vlan))) */
	/*         break; */
	default:
		return -7;
	}

	if (!(dnshdr = parse_udp(udp)))
		return -8;

	/* At this point dnslen is what the frame holds behind the UDP header,
	 * which may include link-layer padding. The UDP length field is the
	 * authoritative size of the datagram: drop the packet if it is
	 * inconsistent with the frame, otherwise use it for the DNS length. */
	udp_len = ntohs(udp->len);
	if (udp_len < sizeof(*udp) || udp_len - sizeof(*udp) > dnslen)
		return -10;
	dnslen = udp_len - (uint32_t) sizeof(*udp);

	/* NOTE(review): the capacity assumes the frame data starts at the very
	 * beginning of a UMEM frame; if the received data is offset by headroom
	 * the usable space is smaller — confirm against the frame layout. */
	query_set_buffer_data(query, dnshdr, XDP_FRAME_SIZE - data_before_dnshdr_len);

	/* the remote address of the query is the sender of the packet, so the
	 * port is the UDP source port */
	if(ipv6) {
#ifdef INET6
		struct sockaddr_in6* sock6 = (struct sockaddr_in6*)&query->remote_addr;
		sock6->sin6_family = AF_INET6;
		sock6->sin6_port = udp->source;
		sock6->sin6_flowinfo = 0;
		sock6->sin6_scope_id = 0;
		memcpy(&sock6->sin6_addr, &ipv6->saddr, sizeof(ipv6->saddr));
		query->remote_addrlen = (socklen_t)sizeof(struct sockaddr_in6);
#else
		return 0; /* no inet6 no network */
#endif /* INET6 */
#ifdef BIND8_STATS
		STATUP(xdp->nsd, qudp6);
#endif /* BIND8_STATS */
	} else {
		struct sockaddr_in* sock4 = (struct sockaddr_in*)&query->remote_addr;
		sock4->sin_family = AF_INET;
		sock4->sin_port = udp->source;
		sock4->sin_addr.s_addr = ipv4->saddr;
		query->remote_addrlen = (socklen_t)sizeof(struct sockaddr_in);
#ifdef BIND8_STATS
		STATUP(xdp->nsd, qudp);
#endif /* BIND8_STATS */
	}

	query->client_addr    = query->remote_addr;
	query->client_addrlen = query->remote_addrlen;
	query->is_proxied = 0;

	dnslen = parse_dns(xdp->nsd, dnslen, query, query->remote_addr.ss_family);
	if (!dnslen) {
		return -9;
	}

	// Not verifying the packet length (that it fits in an IP packet), as
	// parse_dns truncates too long response messages.
	udp->len = htons((uint16_t) (sizeof(*udp) + dnslen));

	swap_eth(eth);
	swap_udp(udp);

	if (ipv4) {
		__be16 ipv4_old_len = ipv4->tot_len;
		swap_ipv4(ipv4);
		/* compute the total length in host byte order and convert once;
		 * adding two values that are already in network byte order loses
		 * the carry between the bytes on little-endian hosts */
		ipv4->tot_len = htons((uint16_t) (sizeof(*ipv4) + sizeof(*udp) + dnslen));
		csum16_replace(&ipv4->check, ipv4_old_len, ipv4->tot_len);
		udp->check = calc_csum_udp4(udp, ipv4);
	} else if (ipv6) {
		swap_ipv6(ipv6);
		ipv6->payload_len = udp->len;
		udp->check = calc_csum_udp6(udp, ipv6);
	} else {
		log_msg(LOG_ERR, "xdp: we forgot to implement something... oops");
		return 0;
	}

	/* log_msg(LOG_INFO, "xdp: done with processing the packet"); */

	*len = data_before_dnshdr_len + dnslen;
	return 1;
}

/*
 * Handle one batch of received packets on the socket of xdp->queue_index:
 * take up to XDP_RX_BATCH_SIZE descriptors from the RX ring, refill the
 * fill ring, process every packet in place and queue the resulting answers
 * on the TX ring.
 *
 * Frames of dropped packets are returned to the free-frame stack right
 * away. Returns without doing anything if no packet is available.
 */
void xdp_handle_recv_and_send(struct xdp_server *xdp) {
	struct xsk_socket_info *xsk = &xdp->xsks[xdp->queue_index];
	unsigned int recvd, i, reserved, to_send = 0;
	uint32_t idx_rx = 0;
	uint32_t tx_idx = 0;
	int ret;

	recvd = xsk_ring_cons__peek(&xsk->rx, XDP_RX_BATCH_SIZE, &idx_rx);
	if (!recvd) {
		/* no data available */
		return;
	}

	/* give free frames back to the kernel before processing the batch */
	fill_fq(xsk);

	/* Process received packets */
	for (i = 0; i < recvd; ++i) {
		uint64_t addr = xsk_ring_cons__rx_desc(&xsk->rx, idx_rx)->addr;
		uint32_t len = xsk_ring_cons__rx_desc(&xsk->rx, idx_rx++)->len;

		uint8_t *pkt = xsk_umem__get_data(xsk->umem->buffer, addr);
		if ((ret = process_packet(xdp, pkt, &len, xdp->queries[i])) <= 0) {
			/* drop packet */
			xsk_free_umem_frame(xsk, addr);
		} else {
			/* remember the frame (with its updated length) for tx */
			umem_ptrs[to_send].addr = addr;
			umem_ptrs[to_send].len = len;
			++to_send;
		}
		/* we can reset the query directly after each packet processing,
		 * because the reset does not delete the underlying buffer/data.
		 * However, if we, in future, need to access data from the query
		 * struct when sending the answer, this needs to change.
		 * This also means, that currently a single query instance (and
		 * not an array) would suffice for this implementation. */
		query_reset(xdp->queries[i], UDP_MAX_MESSAGE_LEN, 0);

		/* xsk->stats.rx_bytes += len; */
	}

	/* all peeked RX descriptors have been consumed */
	xsk_ring_cons__release(&xsk->rx, recvd);
	/* xsk->stats.rx_packets += rcvd; */

	/* Process sending packets */

	/* TODO: at least send as many packets as slots are available */
	reserved = xsk_ring_prod__reserve(&xsk->tx, to_send, &tx_idx);
	// if we can't reserve to_send frames, we'll get 0 frames, so
	// no need to "un-reserve"
	if (reserved != to_send) {
		// not enough tx slots available, drop packets
		log_msg(LOG_ERR, "xdp: not enough TX frames available, dropping "
		        "whole batch");
		for (i = 0; i < to_send; ++i) {
			xsk_free_umem_frame(xsk, umem_ptrs[i].addr);
			umem_ptrs[i].addr = XDP_INVALID_UMEM_FRAME;
			umem_ptrs[i].len = 0;
		}
#ifdef BIND8_STATS
		xdp->nsd->st->txerr += to_send;
#endif /* BIND8_STATS */
		/* nothing left to transmit; the loop and submit below become
		 * no-ops */
		to_send = 0;
	}

	/* write one TX descriptor per answer and clear the scratch entry */
	for (i = 0; i < to_send; ++i) {
		uint64_t addr = umem_ptrs[i].addr;
		uint32_t len = umem_ptrs[i].len;
		xsk_ring_prod__tx_desc(&xsk->tx, tx_idx)->addr = addr;
		xsk_ring_prod__tx_desc(&xsk->tx, tx_idx)->len = len;
		tx_idx++;
		xsk->outstanding_tx++;
		umem_ptrs[i].addr = XDP_INVALID_UMEM_FRAME;
		umem_ptrs[i].len = 0;
	}

	xsk_ring_prod__submit(&xsk->tx, to_send);

	/* wake up kernel for tx if needed and collect completed tx buffers */
	handle_tx(xsk);
	/* TODO: maybe call fill_fq(xsk) here as well? */
}

/*
 * Take all finished entries off the completion ring, return their frames to
 * the free-frame stack and lower outstanding_tx accordingly (never below
 * zero).
 */
static void drain_cq(struct xsk_socket_info *xsk) {
	uint32_t idx_cq;
	uint32_t finished;

	/* free completed TX buffers */
	finished = xsk_ring_cons__peek(&xsk->umem->cq,
	                               XSK_RING_CONS__NUM_DESCS,
	                               &idx_cq);
	if (finished == 0)
		return;

	for (uint32_t left = finished; left > 0; --left) {
		uint64_t frame = *xsk_ring_cons__comp_addr(&xsk->umem->cq, idx_cq);

		xsk_free_umem_frame(xsk, frame);
		++idx_cq;
	}

	xsk_ring_cons__release(&xsk->umem->cq, finished);

	if (finished < xsk->outstanding_tx)
		xsk->outstanding_tx -= finished;
	else
		xsk->outstanding_tx = 0;
}

/*
 * If transmissions are outstanding: kick the kernel when the TX ring asks
 * for a wakeup, then collect the frames of completed transmissions.
 */
static void handle_tx(struct xsk_socket_info *xsk) {
	if (xsk->outstanding_tx) {
		if (xsk_ring_prod__needs_wakeup(&xsk->tx)) {
			sendto(xsk_socket__fd(xsk->xsk), NULL, 0, MSG_DONTWAIT, NULL, 0);
		}

		drain_cq(xsk);
	}

	// The TX-queue pointers are not refreshed here explicitly:
	// prod__reserve does that itself whenever the cached pointers
	// do not show enough free frames.
	/* xsk_prod_nb_free(&xsk->tx, XSK_RING_PROD__NUM_DESCS/4); */
}

#endif /* USE_XDP */
