/* 
 *  hxen_ioctl.c
 *  hxen
 *
 *  Copyright 2009 Citrix Systems, Inc. All rights reserved.
 * 
 */

#include "hxen.h"
#include "hxen_call.h"

#include <ntddk.h>
#include <xen/errno.h>
#include <xen/types.h>

#include <hxen_ioctl.h>

#define KXEN_DEFINE_SYMBOLS_PROTO
#include <hxen_link.h>
KXEN_PROTOTYPES(extern)

/*
 * Per-ioctl tracing.  Disabled by default: when the toggle below is 0,
 * IOCTL_TRACE() expands to nothing and its arguments are not evaluated.
 * Set the toggle to 1 to route the traces through dprintk().
 */
#define KXEN_IOCTL_TRACE_ENABLED 0

#if KXEN_IOCTL_TRACE_ENABLED
#define IOCTL_TRACE(fmt, ...) dprintk(fmt, __VA_ARGS__)
#else
#define IOCTL_TRACE(fmt, ...)
#endif

/*
 * Driver lifecycle state, starting at KXEN_MODE_IDLE (static storage is
 * zero-initialised).  The values must stay in increasing order:
 * CHECK_KXEN_MODE / CHECK_KXEN_MODE_NOT compare hxen_mode against them
 * with < and >=.
 */
static int hxen_mode;
enum {
	KXEN_MODE_IDLE = 0,
	KXEN_MODE_LOADED = 1,
	KXEN_MODE_INITIALIZED = 2
};

/*
 * Record an ioctl failure in the enclosing function's locals and log it.
 * Expects `IoStatus`, `ret` and `ioctl_return` to be in scope at the point
 * of use: sets IoStatus->Status to s, ret to r, ioctl_return to FALSE, and
 * reports fmt/args through fail_msg().  IoStatus->Status is written before
 * fail_msg() runs because several callers pass IoStatus->Status as the
 * value to print.  fmt is left unparenthesised so callers can build it
 * from adjacent string literals.
 */
#define IOCTL_FAILURE(s, r, fmt, ...) do {	\
	ioctl_return = FALSE;			\
	ret = (r);				\
	IoStatus->Status = (s);			\
	fail_msg(fmt, __VA_ARGS__);		\
    } while (0)
    
/*
 * Move the driver to lifecycle state m.  Expands to a single parenthesised
 * expression with no trailing semicolon, so `SET_KXEN_MODE(x);` is exactly
 * one statement and is safe in an unbraced if/else.
 */
#define SET_KXEN_MODE(m)			\
    (hxen_mode = (m))
/*
 * Lifecycle-state guards used inside hxen_ioctl().  hxen_mode values are
 * ordered, so:
 *   CHECK_KXEN_MODE(m, id)     rejects the request while hxen_mode < m;
 *   CHECK_KXEN_MODE_NOT(m, id) rejects the request once hxen_mode >= m.
 * A rejected request is recorded with IOCTL_FAILURE (STATUS_UNSUCCESSFUL,
 * EINVAL, "<id> invalid sequence") and control jumps to the caller's
 * `out` label.
 */
#define KXEN_MODE_GUARD(reject, id) do {				\
	if (reject) {							\
	    IOCTL_FAILURE(STATUS_UNSUCCESSFUL, EINVAL,			\
			  "hxen_ioctl(" id ") invalid sequence");	\
	    goto out;							\
	}								\
    } while (0)
#define CHECK_KXEN_MODE(m, id)		KXEN_MODE_GUARD(hxen_mode < (m), id)
#define CHECK_KXEN_MODE_NOT(m, id)	KXEN_MODE_GUARD(hxen_mode >= (m), id)

/*
 * KXEN_OP - run a generic hxen operation from inside hxen_ioctl().
 *
 * name: string literal used in trace and failure messages.
 * fn:   operation, called as fn((arg *)InputBuffer, <extra args>).
 * arg:  type the input buffer is interpreted as; InputBufferLength must be
 *       at least sizeof(arg).
 * ...:  extra arguments appended after the buffer.  Every call site in this
 *       file passes at least one, so no empty-__VA_ARGS__ handling is needed.
 *
 * Order of checks: the driver must be KXEN_MODE_INITIALIZED (otherwise
 * CHECK_KXEN_MODE fails and jumps to `out`).  Next, a NULL or short input
 * buffer is rejected with EINVAL and a jump to `out`.  fn then runs with
 * the thread pinned by hxen_cpu_pin_dom0(), inside a try/except.  If an
 * exception is caught, its code is stored in IoStatus->Status and ret is
 * set to EINVAL.  hxen_cpu_unpin() runs on both the normal and the
 * exception path.
 *
 * Relies on the enclosing function's InputBuffer, InputBufferLength,
 * IoStatus, ret and ioctl_return locals and its `out` label.
 * NOTE(review): KXEN_EXCEPTION_EXECUTE_HANDLER is assumed to be an
 * always-handle SEH filter; its definition is not visible here.
 */
#define KXEN_OP(name, fn, arg, ...) do {				\
	IOCTL_TRACE("hxen_ioctl(" name ", %p, %x)\n", InputBuffer,	\
		    InputBufferLength);					\
	CHECK_KXEN_MODE(KXEN_MODE_INITIALIZED, name);			\
	if (InputBufferLength < sizeof(arg) || InputBuffer == NULL) {	\
	    IOCTL_FAILURE(STATUS_UNSUCCESSFUL, EINVAL,			\
			  "hxen_ioctl(" name ") input arguments");	\
	    goto out;							\
	}								\
	hxen_cpu_pin_dom0();						\
	try {								\
	    ret = fn((arg *)InputBuffer, __VA_ARGS__);			\
	} except (KXEN_EXCEPTION_EXECUTE_HANDLER) {			\
	    IOCTL_FAILURE(GetExceptionCode(), EINVAL,			\
			  "hxen_ioctl(" name ") exception: 0x%08X",	\
			  IoStatus->Status);				\
	}								\
	hxen_cpu_unpin();						\
    } while (0)

/*
 * KXEN_DOM0_CALL - forward a dom0 hypercall-style request from
 * hxen_ioctl().
 *
 * name: string literal used in trace and failure messages.
 * fn:   handler, called as fn(InputBuffer) (the buffer is passed uncast).
 * arg:  descriptor type; only used for the sizeof(arg) length check.
 *
 * The mode and input-buffer validation is the same as in KXEN_OP.  While
 * fn runs, the thread is pinned by hxen_cpu_pin_dom0(), and the local
 * exception registration record hxen_rec is hooked with
 * hxen_info->ki_dom0_current.  They are released in the reverse order:
 * unhook, then unpin.  Presumably the hooked record makes exceptions raised
 * under the dom0 context unwind correctly; TODO confirm against the
 * HOOK_EXCEPTION_REGISTRATION_RECORD definition.  Statement order here is
 * significant.
 *
 * A caught exception stores its code in IoStatus->Status and sets ret to
 * EINVAL.  Relies on the enclosing function's InputBuffer,
 * InputBufferLength, IoStatus, ret and ioctl_return locals and its `out`
 * label.
 */
#define KXEN_DOM0_CALL(name, fn, arg) do {				\
	DECLARE_EXCEPTION_REGISTRATION_RECORD(hxen_rec);		\
	IOCTL_TRACE("hxen_ioctl(" name ", %p, %x)\n", InputBuffer,	\
		    InputBufferLength);					\
	CHECK_KXEN_MODE(KXEN_MODE_INITIALIZED, name);			\
	if (InputBufferLength < sizeof(arg) || InputBuffer == NULL) {	\
	    IOCTL_FAILURE(STATUS_UNSUCCESSFUL, EINVAL,			\
			  "hxen_ioctl(" name ") input arguments");	\
	    goto out;							\
	}								\
	hxen_cpu_pin_dom0();						\
	HOOK_EXCEPTION_REGISTRATION_RECORD(hxen_rec,			\
					   hxen_info->ki_dom0_current); \
	try {								\
	    ret = fn(InputBuffer);					\
	} except (KXEN_EXCEPTION_EXECUTE_HANDLER) {			\
	    IOCTL_FAILURE(GetExceptionCode(), EINVAL,			\
			  "hxen_ioctl(" name ") exception: 0x%08X",	\
			  IoStatus->Status);				\
	}								\
	UNHOOK_EXCEPTION_REGISTRATION_RECORD(hxen_rec);			\
	hxen_cpu_unpin();						\
    } while (0)

/*
 * hxen_ioctl - dispatch an hxen control request.
 *
 * Decodes IoControlCode and forwards the request to the matching hxen_*
 * entry point.  The lifecycle state in hxen_mode controls which requests
 * are accepted:
 *   - KXENLOAD moves the driver from idle to KXEN_MODE_LOADED.
 *   - KXENINIT moves it from LOADED to KXEN_MODE_INITIALIZED.
 *   - KXENSHUTDOWN moves it back to LOADED.
 *   - KXENUNLOAD moves it back to KXEN_MODE_IDLE.
 *   - The VM operations (KXEN_OP / KXEN_DOM0_CALL) require INITIALIZED.
 * A state change happens only when the hxen_* call returns 0.
 *
 * NOTE(review): the parameter list matches a FAST_IO_DEVICE_CONTROL
 * callback.  FileObject and Wait are not used here.
 *
 * On return:
 *   IoStatus->Status      - STATUS_SUCCESS, or the failure/exception code
 *                           recorded by IOCTL_FAILURE;
 *   IoStatus->Information - the hxen return code `ret` (EINVAL for
 *                           validation failures and caught exceptions);
 *   return value          - TRUE, or FALSE if IOCTL_FAILURE was raised.
 */
BOOLEAN
hxen_ioctl(
    __in        FILE_OBJECT        *FileObject,
    __in        BOOLEAN             Wait,
    __in_opt    VOID               *InputBuffer,
    __in        ULONG               InputBufferLength,
    __out_opt   VOID               *OutputBuffer,
    __in        ULONG               OutputBufferLength,
    __in        ULONG               IoControlCode,
    __out       IO_STATUS_BLOCK    *IoStatus,
    __in        DEVICE_OBJECT      *DeviceObject
    )
{
    struct device_extension *devext;
    /* hxen return code; reported back through IoStatus->Information. */
    int ret = EINVAL;
    /* Cleared by IOCTL_FAILURE; becomes the function's return value. */
    BOOLEAN ioctl_return = TRUE;

    IoStatus->Status = STATUS_SUCCESS;
    IoStatus->Information = 0;

    devext = DeviceObject->DeviceExtension;

    switch (IoControlCode) {
    case KXENVERSION:
	IOCTL_TRACE("hxen_ioctl(KXENVERSION, %p, %x)\n", OutputBuffer,
		    OutputBufferLength);
	if (OutputBufferLength < sizeof(struct hxen_version_desc) ||
	    OutputBuffer == NULL) {
	    IOCTL_FAILURE(STATUS_UNSUCCESSFUL, EINVAL,
			  "hxen_ioctl(KXENVERSION) output arguments");
	    break;
	}
	try {
	    ret = hxen_version((struct hxen_version_desc *)OutputBuffer);
	} except (KXEN_EXCEPTION_EXECUTE_HANDLER) {
	    IOCTL_FAILURE(GetExceptionCode(), EINVAL,
			  "hxen_ioctl(KXENVERSION) exception: 0x%08X",
			  IoStatus->Status);
	}
	break;
    case KXENLOAD:
	IOCTL_TRACE("hxen_ioctl(KXENLOAD, %p, %x)\n", InputBuffer,
		    InputBufferLength);
	CHECK_KXEN_MODE_NOT(KXEN_MODE_LOADED, "KXENLOAD");
	/*
	 * hxen_load() is handed the caller's buffer as a
	 * struct hxen_load_desc *.  Reject a missing or short buffer the
	 * same way KXEN_OP / KXEN_DOM0_CALL do, instead of letting the
	 * kernel read past its end.
	 */
	if (InputBufferLength < sizeof(struct hxen_load_desc) ||
	    InputBuffer == NULL) {
	    IOCTL_FAILURE(STATUS_UNSUCCESSFUL, EINVAL,
			  "hxen_ioctl(KXENLOAD) input arguments");
	    break;
	}
	try {
	    ret = hxen_load((struct hxen_load_desc *)InputBuffer);
	    if (ret == 0)
		SET_KXEN_MODE(KXEN_MODE_LOADED);
	} except (KXEN_EXCEPTION_EXECUTE_HANDLER) {
	    IOCTL_FAILURE(GetExceptionCode(), EINVAL,
			  "hxen_ioctl(KXENLOAD) exception: 0x%08X",
			  IoStatus->Status);
	}
	break;
    case KXENUNLOAD:
	IOCTL_TRACE("hxen_ioctl(KXENUNLOAD)\n");
	/*
	 * NOTE(review): only LOADED is required, so UNLOAD is also accepted
	 * while INITIALIZED (without a prior SHUTDOWN).  Confirm that
	 * hxen_unload() copes with that.
	 */
	CHECK_KXEN_MODE(KXEN_MODE_LOADED, "KXENUNLOAD");
	ret = hxen_unload();
	if (ret == 0)
	    SET_KXEN_MODE(KXEN_MODE_IDLE);
	break;
    case KXENINIT:
	IOCTL_TRACE("hxen_ioctl(KXENINIT)\n");
	CHECK_KXEN_MODE(KXEN_MODE_LOADED, "KXENINIT");
	CHECK_KXEN_MODE_NOT(KXEN_MODE_INITIALIZED, "KXENINIT");
	ret = hxen_init();
	if (ret == 0)
	    SET_KXEN_MODE(KXEN_MODE_INITIALIZED);
	/* XXX per VM/vcpu */
	/* Note: (re)initialised even when hxen_init() failed. */
	KeInitializeEvent(&devext->de_runnable, NotificationEvent, FALSE);
	break;
    case KXENSHUTDOWN:
	IOCTL_TRACE("hxen_ioctl(KXENSHUTDOWN)\n");
	CHECK_KXEN_MODE(KXEN_MODE_INITIALIZED, "KXENSHUTDOWN");
	ret = hxen_shutdown();
	if (ret == 0)
	    SET_KXEN_MODE(KXEN_MODE_LOADED);
	break;
    case KXENREBOOT:
	IOCTL_TRACE("hxen_ioctl(KXENREBOOT)\n");
	/* NOTE(review): ret is left at its initial EINVAL on this path. */
	kdbgreboot();
	break;
    case KXENKEYHANDLER:
	KXEN_OP("KXENKEYHANDLER", hxen_keyhandler, char, InputBufferLength);
	break;
    case KXENHYPERCALL:
	KXEN_DOM0_CALL("KXENHYPERCALL", hxen_do_hypercall,
		       struct hxen_hypercall_desc);
	break;
    case KXENMEMOP:
	KXEN_DOM0_CALL("KXENMEMOP", hxen_do_memop, struct hxen_memop_desc);
	break;
    case KXENHVMOP:
	KXEN_DOM0_CALL("KXENHVMOP", hxen_do_hvmop, struct hxen_hvmop_desc);
	break;
    case KXENDOMCTL:
	KXEN_DOM0_CALL("KXENDOMCTL", hxen_do_domctl, struct hxen_domctl_desc);
	break;
    case KXENMMAPBATCH:
	KXEN_OP("KXENMMAPBATCH", hxen_mmapbatch, struct hxen_mmapbatch_desc);
	break;
    case KXENMUNMAP:
	KXEN_OP("KXENMUNMAP", hxen_munmap, struct hxen_munmap_desc);
	break;
    case KXENVMAPPINGS:
	KXEN_OP("KXENVMAPPINGS", hxen_create_vmappings,
		struct hxen_vmappings_desc, devext);
	break;
    case KXENEXECUTE:
	IOCTL_TRACE("hxen_ioctl(KXENEXECUTE)\n");
	CHECK_KXEN_MODE(KXEN_MODE_INITIALIZED, "KXENEXECUTE");
	ret = hxen_execute();
	break;
    case KXENSETIOEMUEVENTS:
	/* KXEN_OP emits the trace for this request. */
	KXEN_OP("KXENSETIOEMUEVENTS", hxen_set_ioemu_events,
		struct hxen_ioemu_events_desc, devext);
	break;
    default:
	IOCTL_TRACE("hxen_ioctl(%lx)\n", IoControlCode);
	IOCTL_FAILURE(STATUS_NOT_IMPLEMENTED, EINVAL,
		      "hxen_ioctl(%lx) not implemented", IoControlCode);
	break;
    }

  out:
    IoStatus->Information = ret;
    return ioctl_return;
}
