/* $NetBSD$ */

/*-
 * Copyright (c) 2011 The NetBSD Foundation, Inc.
 * All rights reserved.
 *
 * This code is derived from software contributed to The NetBSD Foundation
 * by Cherry G. Mathew <cherry@NetBSD.org>
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/cdefs.h>

__RCSID("$NetBSD$");

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <sched.h>
#include <semaphore.h>
#include <setjmp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/param.h>
#include <sys/sysctl.h>

#include <atf-c.h>

/* Barrier stuff */
/* We use a semaphore for barriers because it works for fork(2) as
 * well as pthread(3)
 */
/*
 * Counting barrier guarded by a named POSIX semaphore.
 * vm_barrier_init() sets stile to the number of participants; each
 * vm_barrier_hold() call decrements it and then spins until it reads 0.
 */
struct barrier {
	char *semfile; /* File name of the semaphore */
	sem_t *lock; /* operates as a mutex */
	volatile size_t stile; /* participants still expected; 0 opens the barrier */
};


/*
 * Initialise *bt as a barrier for ncpus participants.
 *
 * Opens (creating it if necessary) the named semaphore that serves as
 * the barrier's mutex, stores a heap copy of the semaphore name in
 * bt->semfile so that vm_barrier_destroy() can unlink it, and sets the
 * turnstile count to ncpus.
 *
 * Every failure is fatal: a diagnostic is printed and the process
 * abort()s.
 */
static void
vm_barrier_init(struct barrier *bt, size_t ncpus)
{
	/*
	 * XXX: the name is fixed (no unique name is generated), so every
	 * barrier initialised by this function opens the same semaphore.
	 */
	static const char semname[] = "/semXXXX";

	assert(bt != NULL);

	bt->lock = sem_open(semname, O_CREAT, 0600, 1);
	if (bt->lock == SEM_FAILED) {
		fprintf(stderr, "Unable to open semaphore\n");
		perror("sem_open():");
		abort();
	}

	/*
	 * Size the copy from the array itself.  Sizing it from a
	 * pointer to the name yields sizeof(char *), which is smaller
	 * than the name and made the copy below overrun the buffer.
	 */
	bt->semfile = malloc(sizeof semname);
	if (bt->semfile == NULL) {
		fprintf(stderr, "Unable to malloc() for filename\n");
		sem_close(bt->lock);
		sem_unlink(semname);
		abort();
	}

	/* sizeof semname includes the terminating NUL. */
	memcpy(bt->semfile, semname, sizeof semname);
	bt->stile = ncpus;
}

/*
 * Tear down a barrier set up by vm_barrier_init(): zero the turnstile
 * count, close and unlink the named semaphore, and release the stored
 * semaphore name.
 */
static void
vm_barrier_destroy(struct barrier *bt)
{
	char *name;

	assert(bt != NULL);

	name = bt->semfile;

	bt->stile = 0;
	sem_close(bt->lock);
	sem_unlink(name);
	free(name);
}

/*
 * Rendezvous at barrier bt.
 *
 * The caller first checks in by decrementing bt->stile under the
 * semaphore, then spins, re-reading bt->stile under the semaphore,
 * until it observes zero.  The semaphore is only ever taken with
 * sem_trywait(), so all waiting is busy-waiting.
 *
 * Progress is traced on stdout with the "a)", "b)" and "c)" lines.
 */
static void
vm_barrier_hold(struct barrier *bt)
{
	size_t stile;

	assert(bt != NULL);
	assert(bt->lock != NULL);
	assert(bt->stile != 0);

	/* stile is a size_t, hence %zu rather than %zd. */
	printf("a) stile == %zu\n", bt->stile);

	/* Check in: count ourselves off while holding the lock. */
	while (sem_trywait(bt->lock)) /* spinwait */
		assert(errno == EAGAIN);
	bt->stile--;
	stile = bt->stile; /* snapshot taken under the lock for the trace */
	sem_post(bt->lock);
	printf("b) stile == %zu\n", stile);

	/* Wait: poll the count until every participant has checked in. */
	do {
		while (sem_trywait(bt->lock)) /* spinwait */
			assert(errno == EAGAIN);
		stile = bt->stile;
		sem_post(bt->lock);
	} while (stile);
	printf("c) stile == %zu\n", stile);
}


/* 
 * The goal of these tests is to stress the kernel pmap
 * implementation, from userspace. 
 */

/* 
 * Thread thrash: This test fires off one thread per CPU. 
 * Each thread makes synchronised, interleaved data accesses to a
 * shared, locked page of memory. Each thread has a unique mapping to
 * this page, obtained via mmap(). The mappings to the shared page are
 * torn down and created afresh every time, in order to exercise the
 * pmap routines. (XXX: break down this test into smaller tests that
 * exercise identifiable areas of pmap.)
 *
 * This test only operates on a single pmap.
 */


/*
 * Per-thread body for the thread_thrash test case.
 *
 * arg must point to an int holding the thread's id.
 * XXX: no thrashing is implemented yet; the routine only announces
 * itself on stdout.
 */
static void
thrash(void *arg)
{
	const int *idp;

	assert(arg != NULL);
	idp = arg;

	/* Bind threads to given cpu */
	printf("I am thread #%d\n", *idp);
}

/*
 * Query the CTL_HW / HW_NCPU sysctl.
 * Returns the value reported by the kernel, or 0 if the query fails.
 */
static cpuid_t
getcpus(void)
{
	static int mib[2] = { CTL_HW, HW_NCPU };
	int count;
	size_t countlen = sizeof count;

	if (sysctl(mib, __arraycount(mib), &count, &countlen, NULL, 0) != -1) {
		return count;
	}

	return 0;
}

/*
 * Intended to set the fault handler of the current process to
 * fault_routine(), with prep_fault(NULL) restoring the default
 * handler, and to return the fault handler that has been set.
 *
 * XXX: not implemented yet -- the argument is ignored and NULL is
 * always returned.
 */

static void *
prep_fault(void *fault_routine)
{
	void *installed = NULL; /* XXX: */

	(void)fault_routine;
	return installed;
}

/* Thread wrappers */
/* Quick note on thread "id". Since we assume one thread per cpu, the
 * cpuid is used in place of the thread "id" for all practical
 * purposes.
 */

/*
 * Saved jump contexts, one slot per possible CPU.  setjmp_tramp() fills
 * sequel[t->cid]; thread_pagefault() jumps back through slot 0.
 */
static jmp_buf sequel[MAXCPUS];

/*
 * Fault handler, to test for legitimate page faults.
 * Resumes execution at the setjmp() point saved in sequel[].
 * XXX: slot 0 is hard-wired instead of the faulting thread's own slot.
 */
static void
thread_pagefault(void)
{
	jmp_buf *resume = &sequel[0 /* XXX */];

	longjmp(*resume, 1);

	/* Only reached if longjmp() somehow returns. */
	fputs("pagefault handler did not longjmp() !", stderr);
	abort();
}

/* What a spawned thread runs, as recorded by thread_spawn(). */
struct thread_arg {
	void (*func)(void *); /* thread body; setjmp_tramp() calls func(arg) */
	void (*abortf)(void *); /* abort hook; thread_abort() calls it if non-NULL */
	void *arg; /* opaque argument passed to func and abortf */
};

/* Per-thread bookkeeping allocated by thread_spawn(). */
struct thread_ctx {
	cpuid_t cid; /* The cpu number we are running on */
	pthread_t pth; /* thread handle, written by pthread_create() */
	cpuset_t *cset; /* cpuset with cid set; used for the affinity call */
	struct barrier init_bar; /* start-up rendezvous of thread_spawn() and setjmp_tramp() */
	struct thread_arg ctx; /* function, abort hook and argument to run */
};

/* Can only be called from own thread */

static void
thread_exit(struct thread_ctx *t)
{
	assert(t != NULL);
	assert(pthread_equal(t->pth, pthread_self()));

	vm_barrier_destroy(&t->init_bar);
	cpuset_destroy(t->cset);
	free(t);
	pthread_exit(NULL);
}

/* Two contexts denote the same thread iff they are the same object. */
static inline bool
thread_equal(struct thread_ctx *t1, struct thread_ctx *t2)
{
	if (t1 != t2) {
		return false;
	}

	return true;
}

/*
 * Like thread_exit(), except that the abort callback, if one is
 * registered, is run with the thread's argument before the context is
 * torn down.  Must be called from the thread that t describes; does
 * not return.
 */
static void
thread_abort(struct thread_ctx *t)
{
	void (*hook)(void *);

	assert(t != NULL);
	assert(pthread_equal(t->pth, pthread_self()));

	hook = t->ctx.abortf;
	if (hook != NULL) {
		hook(t->ctx.arg);
	}

	/* Same teardown sequence as thread_exit(). */
	vm_barrier_destroy(&t->init_bar);
	cpuset_destroy(t->cset);
	free(t);
	pthread_exit(NULL);
}

/*
 * pthread start routine used by thread_spawn(); arg is the new
 * thread's struct thread_ctx.
 *
 * The thread first meets its creator at t->init_bar, then records a
 * setjmp() context in sequel[t->cid] and runs t->ctx.func(t->ctx.arg).
 * A normal return from func ends the thread through thread_exit().  If
 * the saved context is resumed via longjmp(), the thread resets the
 * fault handler and ends through thread_abort() instead.
 *
 * Neither exit path returns, so the trailing "return NULL" only keeps
 * the compiler happy.
 */
static void *
setjmp_tramp(void *arg)
{
	assert(arg != NULL);

	pthread_t pth;
	struct thread_ctx *t = arg;

	pth = pthread_self();

	/*
	 * Debug trace.  %p needs a void * argument, and cid is printed
	 * with %lu like everywhere else in this file.
	 */
	printf("child addr of t->pth == %p\n", (void *)&t->pth);
	printf("child cid == %lu\n", (unsigned long)t->cid);
	printf("child pth == %p\n", (void *)t->pth);
	sleep(1);
	vm_barrier_hold(&t->init_bar); /* Sync with thread_spawn */
	printf("child pth after  == %p\n", (void *)t->pth);
	printf("child pth self after  == %p\n", (void *)pth);
	if (!pthread_equal(pth, t->pth)) {
		/* The context does not describe this thread: hang here. */
		printf("not the right child\n");
		while(1);
	}

	if (setjmp(sequel[t->cid])) {
		/* 
		 * got here via longjmp() from fault
		 * routine.
		 */

		printf("caught exception\n");
		prep_fault(NULL); /* XXX: reset exception handler */
		thread_abort(t);
	}
	t->ctx.func(t->ctx.arg);
	thread_exit(t);
	return NULL;
}


/*
 * Create a thread bound to CPU cid that runs func(arg).
 *
 *   cid    - CPU number; it also indexes sequel[], so it must be
 *            smaller than MAXCPUS.
 *   func   - thread body (required).
 *   arg    - passed unchanged to func and to abortf.
 *   abortf - hook run by thread_abort(); may be NULL.
 *
 * Returns the context of the new thread, or NULL on failure.  Once the
 * thread has been released it owns the context and frees it itself in
 * thread_exit()/thread_abort().
 *
 * If binding the thread to its CPU fails, NULL is returned but the
 * thread cannot be stopped from here: it is detached and left to run
 * func(arg) unbound and to free its own context.
 */
static struct thread_ctx *
thread_spawn(cpuid_t cid,  /* cpu number */
	     void (*func)(void *),
	     void *arg,
	     void (*abortf)(void *))
{
	struct thread_ctx *t;
	cpuset_t *cpuset;
	pthread_t pth;
	int error;

	assert(func != NULL);
	assert(cid < MAXCPUS); /* cid indexes sequel[MAXCPUS] */

	t = malloc(sizeof *t);
	if (t == NULL) {
		return NULL;
	}

	cpuset = cpuset_create();

	if (cpuset == NULL) {
		printf("Could not create cpuset\n");
		free(t);
		return NULL;
	}

	if (cpuset_set(cid, cpuset) == -1) {
		printf("Could not set cpuset affinity to cpu%lu \n",
		       (unsigned long)cid);
		cpuset_destroy(cpuset);
		free(t);
		return NULL;
	}

	t->cset = cpuset;
	t->cid = cid;
	t->ctx.func = func;
	t->ctx.arg = arg;
	/* Stored unconditionally: thread_abort() tests it against NULL. */
	t->ctx.abortf = abortf;

	vm_barrier_init(&t->init_bar, 2);

	printf("creating new thread for func: %p\n", t->ctx.func);

	printf("addr of t->pth == %p\n", (void *)&t->pth);
	if (pthread_create(&t->pth, NULL,
			   setjmp_tramp, t)) {
		printf("error creating thread \n");
		/* No thread exists, so everything is still ours to release. */
		vm_barrier_destroy(&t->init_bar);
		cpuset_destroy(cpuset);
		free(t);
		return NULL;
	}
	pth = t->pth;
	printf("parent cid == %lu\n", (unsigned long)cid);
	printf("parent pth == %p\n", (void *)pth);

	/*
	 * Set affinity while the child is still held back by the
	 * barrier in setjmp_tramp(): after the rendezvous the child may
	 * run to completion and free t and its cpuset at any moment.
	 */
	error = pthread_setaffinity_np(pth, cpuset_size(cpuset), cpuset);

	vm_barrier_hold(&t->init_bar); /* Sync with setjmp_tramp() */

	if (error) {
		printf("error binding thread to CPU %lu\n",
		       (unsigned long)cid);
		/*
		 * XXX: the thread is running and owns t, so t must not
		 * be freed here.  Detach the thread, since nobody will
		 * join it.
		 */
		pthread_detach(pth);
		return NULL;
	}

	return t;
}

/*
 * Block until the thread described by ctx has terminated and return
 * the result of pthread_join().
 *
 * The context memory is not reaped here: the thread frees it itself in
 * thread_exit(), so ctx must not be used once this call returns.
 *
 * NOTE(review): ctx->pth is read below, yet the thread may already have
 * exited and freed ctx by the time we get here -- confirm the intended
 * ownership of the context.
 *
 * This is slightly lame, but we're a testing framework, not a
 * threading library.
 */

static int
thread_wait(struct thread_ctx *ctx)
{
	assert(ctx != NULL);

	/* ctx is free()d by the thread on thread_exit() */
	return pthread_join(ctx->pth, NULL);
}

/* Per-thread record allocated by spawn() and read by thread(). */
struct tt {
	int tid; /* sequence number assigned by spawn() */
	pthread_t pth; /* thread handle, written by pthread_create() in spawn() */
};

static void *
thread(void *arg)
{
	struct tt *ttp = arg;
	int tid = ttp->tid;
	pthread_t pth = ttp->pth;


	printf("I am thread %d\n", tid);

	if (pthread_equal(pthread_self(), pth)) {
		printf("pthread_self() matches\n");
	}

	sleep(3);
	pthread_exit(NULL);
}

static struct tt *
spawn(void)
{
	static int tid = 0;
	struct tt *ttp;

	printf("spawn entered at tid == %d\n", tid);
	ttp = malloc(sizeof *ttp);
	if (ttp == NULL) {
		fprintf(stderr, "malloc() failed\n");
		abort();
	}

	ttp->tid = tid;
	pthread_create(&ttp->pth, NULL, thread, &ttp->tid);
	printf("spawn finished at tid == %d\n", tid);

	tid++;
	return ttp;
}


ATF_TC(test_thread);
ATF_TC_HEAD(test_thread, tc)
{
	atf_tc_set_md_var(tc, "descr",
			  "test pthreads");
}
ATF_TC_BODY(test_thread, tc)
{

	pthread_join(spawn()->pth, NULL);
	pthread_join(spawn()->pth, NULL);
}

/*
 * thread_thrash: spawn one thrash() thread per detected CPU, each bound
 * to its own CPU, and wait for all of them to finish.
 */
ATF_TC(thread_thrash);
ATF_TC_HEAD(thread_thrash, tc)
{
	atf_tc_set_md_var(tc, "descr",
			  "Thrash the TLB from within a single Address Space");
}

ATF_TC_BODY(thread_thrash, tc)
{
	cpuid_t cpuno, i;
	struct thread_ctx *t[MAXCPUS];
	int tid[MAXCPUS]; /* per-thread ids; must outlive the threads */

	/* 1) Detect no. of cpus via sysctl(3) */
	cpuno = getcpus();

	printf("Detected %lu cpus\n", cpuno);
	ATF_REQUIRE(cpuno > 0);
	/* t[] and tid[] only have MAXCPUS slots. */
	ATF_REQUIRE(cpuno <= MAXCPUS);

	/* 2) Fire off threads */
	printf("new threads\n");
	(void) thread_pagefault;
	for (i = 0; i < cpuno; i++) {
		/* thrash() requires a pointer to its thread id. */
		tid[i] = (int)i;
		t[i] = thread_spawn(i, thrash, &tid[i], NULL);
		if (t[i] == NULL) {
			printf("thread spawn failed for cpu%lu\n", i);
		}

		/* XXX: destroy the other threads ? */
		ATF_REQUIRE(t[i] != NULL);
	}

	/* Wait for threads to join */
	for (i = 0; i < cpuno; i++) {
		thread_wait(t[i]);
		printf("joined\n");
	}

}

/* Test program entry: register the test_thread and thread_thrash cases. */
ATF_TP_ADD_TCS(tp)
{
	ATF_TP_ADD_TC(tp, test_thread);
	ATF_TP_ADD_TC(tp, thread_thrash);
	return atf_no_error();
}
