#ifdef __sun__
#define BSD_COMP		/* for SIOCATMARK on Solaris */
#endif

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>		/* strcasecmp (POSIX) */
#include <unistd.h>

#include <netinet/in.h>
#include <arpa/inet.h>

/* Forward declarations for the functions defined below. */
int usage(void);
int getsocket(void);
void oobtest(int s);
int pollsock(int s);

/*
 * Options selected on the command line; main() sets them, the rest of
 * the program only reads them.
 */
bool oobinline = false;	/* "inline": SO_OOBINLINE instead of recv(MSG_OOB) */
bool oobsync = false;	/* "sync": read up to the SIOCATMARK mark first */
bool delay = false;	/* "delay"/"sleep": sleep(1) before polling */
bool bytes = false;	/* "one": main read path recv()s one byte at a time */


int
usage(void)
{
    fprintf(stderr, "usage: oobrecv [oob | inline] [sync] [one] [delay | sleep] ...\n");
    fprintf(stderr, "default is: %s%s%s\n",
	    oobinline ? "inline" : "oob",
	    oobsync ? " sync" : "",
	    delay ? " delay" : "");

    return EXIT_FAILURE;
}


/*
 * Parse the case-insensitive option words on the command line into the
 * option globals, then connect and run the receive test.  An unknown word
 * makes main() return usage()'s EXIT_FAILURE status.
 */
int
main(int argc, char **argv)
{
    for (int argi = 1; argi < argc; argi++) {
	const char *word = argv[argi];

	if (strcasecmp(word, "inline") == 0) {
	    oobinline = true;
	} else if (strcasecmp(word, "oob") == 0) {
	    oobinline = false;
	} else if (strcasecmp(word, "sync") == 0) {
	    oobsync = true;
	} else if (strcasecmp(word, "one") == 0) {
	    bytes = true;
	} else if (strcasecmp(word, "delay") == 0
		   || strcasecmp(word, "sleep") == 0) {
	    delay = true;
	} else {
	    return usage();
	}
    }

    int sock = getsocket();
    oobtest(sock);
    return 0;
}


/*
 * Create a TCP socket, enable SO_OOBINLINE on it when the "inline" option
 * was given, and connect it to 127.0.0.1 port 12345.
 *
 * Returns the connected descriptor; any failure exits via err(3).
 */
int
getsocket(void)
{
    int fd = socket(PF_INET, SOCK_STREAM, 0);
    if (fd < 0)
	err(EXIT_FAILURE, "socket");

    /* Configured before connect(), so it is in place for all received data. */
    if (oobinline) {
	int on = 1;

	if (setsockopt(fd, SOL_SOCKET, SO_OOBINLINE,
		       (char *)&on, sizeof(on)) < 0)
	    err(EXIT_FAILURE, "SO_OOBINLINE");
    }

    struct sockaddr_in peer;
    memset(&peer, 0, sizeof(peer));
#if !defined(__linux__) && !defined(__sun__)
    peer.sin_len = sizeof(peer);	/* only stacks with a sin_len member */
#endif
    peer.sin_family = AF_INET;
    peer.sin_port = htons(12345);
    peer.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    if (connect(fd, (struct sockaddr *)&peer, sizeof(peer)) < 0)
	err(EXIT_FAILURE, "connect");

    return fd;
}


/*
 * Query SIOCATMARK on s, print "<at mark>" or "<not at mark>", and return
 * the flag the ioctl reported (non-zero means "at mark").
 * Exits via err(3) if the ioctl fails.
 */
static int
report_atmark(int s)
{
    int atmark;

    if (ioctl(s, SIOCATMARK, &atmark) < 0)
	err(EXIT_FAILURE, "SIOCATMARK");

    puts(atmark ? "<at mark>" : "<not at mark>");
    return atmark;
}


/*
 * Read and echo ordinary data from s until SIOCATMARK reports the mark.
 * Returns true once at the mark, false if recv() hits EOF first.
 * Without the EOF check the loop would spin forever on recv() == 0.
 * Exits via err(3) on recv() failure.
 */
static bool
read_to_mark(int s, char *buf, size_t bufsize)
{
    while (!report_atmark(s)) {
	ssize_t nread = recv(s, buf, bufsize, 0);
	if (nread < 0)
	    err(EXIT_FAILURE, "recv");

	printf("recv() = %zd (reading to mark)\n", nread);
	if (nread == 0)
	    return false;

	fwrite(buf, 1, (size_t)nread, stdout);
	printf("\n");
    }
    return true;
}


/*
 * Main receive loop.
 *
 * Poll s and echo everything received to stdout, announcing each poll
 * result, mark state and recv() result.  Urgent data (POLLPRI/POLLRDBAND)
 * is handled as follows:
 *   - without "inline", it is read with recv(MSG_OOB);
 *   - with "inline", it is read with a plain recv();
 *   - with "sync", ordinary data up to the mark is drained first.
 *
 * Returns when the peer closes the connection.  Socket errors and
 * unexpected failures exit via err(3).
 */
void
oobtest(int s)
{
    char buf[128];

    printf("reading urgent data with %s",
	   oobinline ? "SO_OOBINLINE" : "MSG_OOB");
    if (oobsync)
	printf(" with SIOCATMARK");
    if (delay)
	printf(" after delay");
    printf("\n");

    if (delay)
	sleep(1);	/* give the sender time to queue data before polling */

    for (;;) {
	int revents = pollsock(s);
	if (revents == 0)
	    continue;

	if (revents & POLLNVAL) {
	    errno = EBADF;
	    err(EXIT_FAILURE, "poll");
	}

	if (revents & POLLERR) {
	    int sockerr;
	    socklen_t optlen = (socklen_t)sizeof(sockerr);

	    if (getsockopt(s, SOL_SOCKET, SO_ERROR,
			   (char *)&sockerr, &optlen) < 0)
		err(EXIT_FAILURE, "SO_ERROR");

	    errno = sockerr;
	    err(EXIT_FAILURE, NULL);
	}

	bool urgent = (revents & (POLLPRI | POLLRDBAND)) != 0;

	if (oobsync && urgent && !read_to_mark(s, buf, sizeof(buf)))
	    return;	/* EOF before reaching the mark */

	int flags = (!oobinline && urgent) ? MSG_OOB : 0;

	/* Diagnostic: always show the mark state before each read. */
	(void)report_atmark(s);

	ssize_t nread = recv(s, buf, bytes ? 1 : sizeof(buf), flags);
	if (nread < 0) {
	    if (flags & MSG_OOB) {
		/*
		 * NOTE(review): this goes straight back to poll().
		 * If the condition that raised POLLPRI/POLLRDBAND persists,
		 * the loop may repeat this warning rapidly. Confirm this is
		 * acceptable for this test tool.
		 */
		warn("recv(MSG_OOB)");
		continue;
	    }
	    err(EXIT_FAILURE, "recv");
	}

	printf("recv(%s) = %zd\n",
	       flags & MSG_OOB ? "MSG_OOB" : "",
	       nread);

	if (nread == 0)
	    return;

	fwrite(buf, 1, (size_t)nread, stdout);
	printf("\n");
    }
}


/*
 * Block in poll(2) until s reports normal (POLLIN) or urgent
 * (POLLPRI, POLLRDBAND) readability, or an error/hangup condition.
 * Prints the returned event mask and returns it.
 *
 * Returns 0 if poll() reports no ready descriptor or is interrupted by a
 * signal (EINTR); callers should treat 0 as "poll again".  Any other
 * poll() failure exits via err(3).
 */
int
pollsock(int s)
{
    struct pollfd pfd = {
	.fd = s,
	.events = POLLIN | POLLPRI | POLLRDBAND,
	.revents = 0,
    };

    int nready = poll(&pfd, 1, -1);
    if (nready < 0) {
	if (errno == EINTR)
	    return 0;	/* interrupted, not failed: let the caller retry */
	err(EXIT_FAILURE, "poll");
    }

    if (nready == 0)
	return 0;

    int revents = pfd.revents;

    printf("poll: revents = 0x%x", revents);
    if (revents != 0) {
	printf(":");
	if (revents & POLLNVAL)   printf(" NVAL");
	if (revents & POLLERR)    printf(" ERR");
	if (revents & POLLHUP)    printf(" HUP");
	if (revents & POLLIN)     printf(" IN");
	if (revents & POLLPRI)    printf(" PRI");
	if (revents & POLLRDNORM) printf(" RDNORM");
	if (revents & POLLRDBAND) printf(" RDBAND");
    }
    printf("\n");

    return revents;
}
