/*
 *  hxen_mem.c
 *  hxen
 *
 *  Copyright 2009 Citrix Systems, Inc. All rights reserved.
 *
 */

#include "hxen.h"

#include <ntddk.h>
#include <xen/errno.h>

#define KXEN_DEFINE_SYMBOLS_PROTO
#include <hxen_link.h>
KXEN_PROTOTYPES(extern)

/*
 * State of the single Xen heap created by kernel_malloc_heap() and released
 * by kernel_free_heap():
 *   xenheap_pages - one 32-bit machine frame number per heap page
 *   xenheap_va    - kernel virtual mapping of the whole heap
 *   xenheap_mdl   - MDL describing the heap's physical pages
 * Each is NULL when the corresponding resource is not held.
 */
static uint32_t *xenheap_pages = NULL;
static unsigned char *xenheap_va = NULL;
static PMDL xenheap_mdl = NULL;

/*
 * Allocate `size` bytes from the non-paged pool, tagged with KXEN_POOL_TAG.
 * Returns NULL on failure; release the block with kernel_free().
 */
void *
kernel_malloc(size_t size)
{
    void *block;

    block = ExAllocatePoolWithTag(NonPagedPool, size, KXEN_POOL_TAG);
    return block;
}

/*
 * Release a block obtained from kernel_malloc().
 *
 * addr - block to free; NULL is accepted and ignored.
 * size - size of the block in bytes; unused by the pool API and accepted
 *        only for interface compatibility with callers.
 */
void
kernel_free(void *addr, size_t size)
{
    (void)size;

    /* Unlike free(NULL), ExFreePoolWithTag is not documented to accept a
     * NULL pointer, so make NULL a no-op here. */
    if (addr == NULL)
        return;
    ExFreePoolWithTag(addr, KXEN_POOL_TAG);
}

/*
 * Release everything kernel_malloc_heap() acquired: the PFN list, the kernel
 * mapping, the physical pages and the MDL itself.
 *
 * Each step checks its own global, so this is safe on a partially built
 * heap (kernel_malloc_heap() uses it as its error path). The globals are
 * reset to NULL, so repeated calls are harmless.
 */
void
kernel_free_heap(void)
{
    if (xenheap_mdl == NULL)
        return;

    if (xenheap_pages != NULL) {
        size_t page_count = xenheap_mdl->ByteCount >> PAGE_SHIFT;

        /* kernel_free() takes a byte size, matching the
         * page_count * sizeof(uint32_t) allocation in kernel_malloc_heap(). */
        kernel_free(xenheap_pages, page_count * sizeof(*xenheap_pages));
        xenheap_pages = NULL;
    }
    if (xenheap_va != NULL) {
        MmUnmapLockedPages(xenheap_va, xenheap_mdl);
        xenheap_va = NULL;
    }
    /* The pages go back first; the MDL structure itself is pool memory. */
    MmFreePagesFromMdl(xenheap_mdl);
    ExFreePool(xenheap_mdl);
    xenheap_mdl = NULL;
}

void *
kernel_malloc_heap(uint32_t heap_pages, uint32_t **page_list)
{
    PHYSICAL_ADDRESS low_address;
    PHYSICAL_ADDRESS high_address;
    PHYSICAL_ADDRESS skip_bytes;
    PFN_NUMBER *mfnarray;
    uint32_t i;

    // Allocate the heap from the low 4GB of physical memory so that CR3
    // values (page directory pages) can be allocated from anywhere in the
    // heap.

    low_address.QuadPart = 0;
    high_address.QuadPart = 0xffffffff;
    skip_bytes.QuadPart = 0;
    xenheap_mdl = MmAllocatePagesForMdl(low_address, high_address, skip_bytes,
                                        heap_pages << PAGE_SHIFT);
    if (xenheap_mdl == NULL) {
        fail_msg("kernel_malloc_heap: failed to allocate any memory");
        goto out;
    }
    if (xenheap_mdl->ByteCount != heap_pages << PAGE_SHIFT) {
        fail_msg("kernel_malloc_heap: failed, requested %x, got %x pages",
                 heap_pages, xenheap_mdl->ByteCount >> PAGE_SHIFT);
        goto out;
    }
    xenheap_va = MmMapLockedPagesSpecifyCache(xenheap_mdl, KernelMode,
                                              MmCached, NULL, FALSE,
                                              LowPagePriority);
    if (xenheap_va == NULL) {
        fail_msg("kernel_malloc_heap: failed to get kernel mapping.");
        goto out;
    }

    xenheap_pages = kernel_malloc(heap_pages * sizeof(uint32_t));
    if (xenheap_pages == NULL)
	goto out;
    mfnarray = MmGetMdlPfnArray(xenheap_mdl);
    for (i = 0; i < heap_pages; i++) {
        xenheap_pages[i] = (uint32_t)mfnarray[i];
        if (xenheap_pages[i] > hxen_info->ki_max_page) {
            fail_msg("kernel_malloc_heap: invalid mfn %lx at entry %d\n",
                    xenheap_pages[i], i);
            goto out;
        }
    }
    *page_list = xenheap_pages;
    return xenheap_va;

out:
    kernel_free_heap();
    return NULL;
}

/*
 * Largest page count whose hand-built MDL (MDL header followed by one
 * PFN_NUMBER per page, as allocated in mmap_pages()) fits in 65535 bytes.
 * hxen_mmapbatch() rejects batches larger than this.
 * NOTE(review): presumably tied to the MDL's 16-bit size field — confirm.
 */
#define MAX_MDL_MAP_PAGES ((65536 - 1 - sizeof(MDL)) / sizeof(PFN_NUMBER))

/*
 * Map `num` machine frames into the current process's user address space.
 *
 * A temporary MDL is built by hand around the frame list in mfns[0..num-1].
 * After mapping, only the MDL structure is freed; the mapping stays in place
 * until munmap_pages() rebuilds an MDL to tear it down.
 *
 * Returns the user virtual address, or NULL if num is too large, the MDL
 * allocation fails, a frame exceeds hxen_info->ki_max_page, or the mapping
 * raises an exception.
 */
static void *
mmap_pages(unsigned int num, xen_pfn_t *mfns)
{
    MDL *mdl;
    PFN_NUMBER *pfn;
    void *addr;
    unsigned int i;

    /* hxen_map_pages() calls this without a caller-side bound. An unchecked
     * num would overflow the MDL size computation below, and the PFN loop
     * would then write past the allocation. */
    if (num > MAX_MDL_MAP_PAGES) {
        fail_msg("hxen_map_pages_to_user: %u > MAX_MDL_MAP_PAGES %u\n",
                 num, (unsigned int)MAX_MDL_MAP_PAGES);
        return NULL;
    }

    mdl = ExAllocatePoolWithTag(NonPagedPool,
                                sizeof(MDL) + sizeof(PFN_NUMBER) * num,
                                KXEN_POOL_TAG);
    if (mdl == NULL) {
        /* The old message had a %x conversion with no argument (UB). */
        fail_msg("hxen_map_pages_to_user: alloc for mdl failed\n");
        return NULL;
    }
    /* Zero only the header; the PFN array is filled in below. */
    memset(mdl, 0, sizeof(*mdl));

    mdl->ByteCount = num << PAGE_SHIFT;

    pfn = MmGetMdlPfnArray(mdl);
    for (i = 0; i < num; i++) {
        pfn[i] = mfns[i];
        if (pfn[i] > hxen_info->ki_max_page) {
            fail_msg("hxen_map_pages_to_user: invalid mfn %lx at entry %d\n",
                     pfn[i], i);
            addr = NULL;
            goto out;
        }
    }

    /* A UserMode mapping failure is reported by raising an exception. */
    try {
        addr = MmMapLockedPagesSpecifyCache(mdl, UserMode, MmCached, NULL,
                                            FALSE, LowPagePriority);
    } except (KXEN_EXCEPTION_EXECUTE_HANDLER) {
        fail_msg("hxen_map_pages_to_user: "
                 "MmMapLockedPagesSpecifyCache failed %x\n",
                 GetExceptionCode());
        addr = NULL;
    }

  out:
    ExFreePoolWithTag(mdl, KXEN_POOL_TAG);

    return addr;
}

/*
 * Map a batch of guest frames into the current process.
 *
 * kmmapbd->kmd_arr holds kmd_num guest pfns of domain kmd_domid. It is
 * translated in place to machine frames (XENMEM_translate_gpfn_list, with the
 * same array as input and output) and then mapped with mmap_pages(). On
 * success the user address is stored in kmmapbd->kmd_addr.
 *
 * Returns 0 on success, EINVAL if the batch is too large or mapping fails,
 * or the nonzero result of the translate call.
 */
int
hxen_mmapbatch(struct hxen_mmapbatch_desc *kmmapbd)
{
    void *addr;
    int ret;
    struct hxen_memop_desc kmemopd;
    union hxen_memop_arg kmemopa;
    xen_pfn_t *mfns;

    if (kmmapbd->kmd_num > MAX_MDL_MAP_PAGES) {
        /* MAX_MDL_MAP_PAGES is a size_t expression; cast it to match %d. */
        fail_msg("hxen_map_pages_to_user: %d > MAX_MDL_MAP_PAGES %d\n",
                 kmmapbd->kmd_num, (int)MAX_MDL_MAP_PAGES);
        return EINVAL;
    }

    kmemopa.translate_gpfn.domid = kmmapbd->kmd_domid;
    kmemopa.translate_gpfn.nr_gpfns = kmmapbd->kmd_num;
    kmemopa.translate_gpfn.gpfn_list = kmmapbd->kmd_arr;
    kmemopa.translate_gpfn.mfn_list = kmmapbd->kmd_arr;
    kmemopd.kmd_cmd = XENMEM_translate_gpfn_list;
    set_xen_guest_handle(kmemopd.kmd_kma, &kmemopa);
    ret = hxen_try_call(hxen_do_memop, &kmemopd, hxen_info->ki_dom0_current);
    if (ret != 0) {
        fail_msg("hxen_map_pages_to_user: "
                 "XENMEM_translate_gpfn_list failed %x\n", ret);
        return ret;
    }

    /* kmd_arr now holds machine frame numbers. */
    get_xen_guest_handle(mfns, kmmapbd->kmd_arr);
    addr = mmap_pages(kmmapbd->kmd_num, mfns);
    if (addr)
        kmmapbd->kmd_addr = (uint64_t)(uintptr_t)addr;
    else
        ret = EINVAL;

    return ret;
}

static int
munmap_pages(unsigned int num, uint8_t *addr)
{
    MDL *mdl;
    unsigned int i;
    int ret;
    PFN_NUMBER *pfn;
    PHYSICAL_ADDRESS pa;

    mdl = ExAllocatePoolWithTag(NonPagedPool,
				sizeof(MDL) + sizeof(PFN_NUMBER) * num,
				KXEN_POOL_TAG);
    if (mdl == NULL) {
	fail_msg("hxen_unmap_pages_from_user: alloc for mdl failed\n");
	return EINVAL;
    }

    memset(mdl, 0, sizeof(*mdl));

    mdl->ByteCount = num << PAGE_SHIFT;
    mdl->MdlFlags = MDL_PAGES_LOCKED;

    pfn = MmGetMdlPfnArray(mdl);
    for (i = 0; i < num; i++) {
	pa = MmGetPhysicalAddress(addr + (i << PAGE_SHIFT));
	pfn[i] = (PFN_NUMBER)(pa.QuadPart >> PAGE_SHIFT);
    }

    try {
	MmUnmapLockedPages(addr, mdl);
	ret = 0;
    } except (KXEN_EXCEPTION_EXECUTE_HANDLER) {
	fail_msg("hxen_unmap_pages_from_user: "
		 "MmUnmapLockedPages failed %x\n", GetExceptionCode());
	ret = EINVAL;
    }

    ExFreePoolWithTag(mdl, KXEN_POOL_TAG);

    return ret;
}

/*
 * Unmap the kmd->kmd_num pages previously mapped at user address
 * kmd->kmd_addr. Returns munmap_pages()'s result (0 or EINVAL).
 */
int
hxen_munmap(struct hxen_munmap_desc *kmd)
{
    unsigned int npages = kmd->kmd_num;
    uint8_t *base = (uint8_t *)kmd->kmd_addr;

    return munmap_pages(npages, base);
}

/*
 * Allocate the guest's memory and map it into the current process.
 *
 * kvd->kvd_nrpages pages are allocated in chunks of up to
 * (1 << KXEN_VMAPPING_SHIFT) pages, one MDL per chunk, stored in
 * devext->de_mdls. Each chunk is then processed as follows:
 *   - it is mapped into user space;
 *   - its user address is written to vmappings[i].vaddr, where vmappings
 *     comes from kvd->kvd_vmappings;
 *   - every page's frame number is recorded in the p2m array
 *     (devext->de_vm_info.vi_shared.vi_p2m);
 *   - every page's user address is recorded in hxen_info->ki_vmaptable.
 *
 * Returns 0 on success. Returns EINVAL if the device already has pages, a
 * frame is out of range, or mapping/updating user memory fails. Returns
 * ENOMEM on allocation failure. On failure, hxen_free_vmappings() releases
 * what was allocated.
 */
int
hxen_create_vmappings(struct hxen_vmappings_desc *kvd,
                      struct device_extension *devext)
{
    struct hxen_vmapping *vmappings;
    struct hxen_vmap *vmaptable;
    PHYSICAL_ADDRESS lowAddress;
    PHYSICAL_ADDRESS highAddress;
    PHYSICAL_ADDRESS skipBytes;
    int ret = 0;
    int i, j, n;
    uint8_t *va;
    PFN_NUMBER *pfns;

    if (devext->de_nrpages) {
        fail_msg("hxen_create_vmappings: cannot change nr pages");
        return EINVAL;
    }

    devext->de_nrpages = kvd->kvd_nrpages;

    devext->de_vm_info.vi_shared.vi_p2m = kernel_malloc(devext->de_nrpages *
                                                        sizeof(uint32_t));
    if (devext->de_vm_info.vi_shared.vi_p2m == NULL) {
        fail_msg("hxen_create_vmappings: alloc p2m failed");
        ret = ENOMEM;
        goto out;
    }
    memset(devext->de_vm_info.vi_shared.vi_p2m, 0,
           devext->de_nrpages * sizeof(uint32_t));

#define VMAPPING_MDL_PAGES (1 << KXEN_VMAPPING_SHIFT)

    /* Number of chunks, rounding up for a partial last chunk. */
    devext->de_nrmdls = (devext->de_nrpages + VMAPPING_MDL_PAGES - 1) >>
        KXEN_VMAPPING_SHIFT;

    devext->de_mdls = (MDL **)kernel_malloc(devext->de_nrmdls * sizeof(MDL *));
    if (devext->de_mdls == NULL) {
        fail_msg("hxen_create_vmappings: alloc mdl array failed");
        ret = ENOMEM;
        goto out;
    }
    /* hxen_free_vmappings() stops at the first NULL slot. */
    memset(devext->de_mdls, 0, devext->de_nrmdls * sizeof(MDL *));

    get_xen_guest_handle(vmappings, kvd->kvd_vmappings);

    vmaptable = (struct hxen_vmap *)hxen_info->ki_vmaptable;

    lowAddress.QuadPart = 0;
    highAddress.QuadPart = (uint64_t)hxen_info->ki_max_page << PAGE_SHIFT;
    skipBytes.QuadPart = 0;

    for (i = 0; i < devext->de_nrmdls; i++) {
        n = devext->de_nrpages - (i << KXEN_VMAPPING_SHIFT);
        if (n > VMAPPING_MDL_PAGES)
            n = VMAPPING_MDL_PAGES;
        devext->de_mdls[i] = MmAllocatePagesForMdl(lowAddress, highAddress,
                                                   skipBytes, n << PAGE_SHIFT);
        if (devext->de_mdls[i] == NULL ||
            devext->de_mdls[i]->ByteCount != n << PAGE_SHIFT) {
            fail_msg("hxen_create_vmappings: alloc mdl failed");
            ret = ENOMEM;
            goto out;
        }

        /*
         * Validate every frame before exposing the chunk to user space. Each
         * pfn indexes vmaptable and is stored in the p2m as a 32-bit value
         * (so it must be below 16TB). The previous check was an ASSERT,
         * which is a no-op in free builds.
         */
        pfns = MmGetMdlPfnArray(devext->de_mdls[i]);
        for (j = 0; j < n; j++) {
            if (pfns[j] >= hxen_info->ki_max_page) {
                fail_msg("hxen_create_vmappings: invalid mfn at entry %d", j);
                ret = EINVAL;
                goto out;
            }
        }

        /* A UserMode mapping failure is reported by raising an exception. */
        try {
            va = MmMapLockedPagesSpecifyCache(devext->de_mdls[i], UserMode,
                                              MmCached, NULL, FALSE,
                                              LowPagePriority);
        } except (KXEN_EXCEPTION_EXECUTE_HANDLER) {
            fail_msg("hxen_create_vmappings: failed to map pages to userspace");
            ret = EINVAL;
        }
        if (ret)
            goto out;

        /* vmappings is caller-supplied memory; guard the write. */
        try {
            set_xen_guest_handle(vmappings[i].vaddr, va);
        } except (KXEN_EXCEPTION_EXECUTE_HANDLER) {
            fail_msg("hxen_create_vmappings: update vmappings failed");
            ret = EINVAL;
        }
        if (ret) {
            /* Remove this chunk's user mapping before
             * hxen_free_vmappings() frees the pages behind it. */
            MmUnmapLockedPages(va, devext->de_mdls[i]);
            goto out;
        }

        for (j = 0; j < n; j++) {
            devext->de_vm_info.vi_shared.vi_p2m
                [(i << KXEN_VMAPPING_SHIFT) + j] = (uint32_t)pfns[j];
            vmaptable[pfns[j]].vaddr = (va + (j << PAGE_SHIFT));
        }
    }

    devext->de_vm_info.vi_shared.vi_nrpages = devext->de_nrpages;

  out:
    /* NOTE(review): user mappings of chunks completed before a failure are
     * not unmapped here, although their pages are freed — confirm whether
     * the caller tears down the address space or this needs tracking. */
    if (ret)
        hxen_free_vmappings(devext);
    return ret;
}

void
hxen_free_vmappings(struct device_extension *devext)
{
    int i;

    if (devext->de_mdls) {
	for (i = 0; i < devext->de_nrmdls; i++) {
	    if (devext->de_mdls[i] == NULL)
		break;
	    MmFreePagesFromMdl(devext->de_mdls[i]);
	    ExFreePool(devext->de_mdls[i]);
	}
	kernel_free(devext->de_mdls, devext->de_nrmdls * sizeof(MDL *));
	devext->de_mdls = NULL;
    }

    if (devext->de_vm_info.vi_shared.vi_p2m)
	kernel_free(devext->de_vm_info.vi_shared.vi_p2m,
		    devext->de_nrpages * sizeof(uint32_t));
    devext->de_vm_info.vi_shared.vi_p2m = NULL;

    devext->de_nrpages = 0;
}

/*
 * Exported entry point: map `num` machine frames from mfns into the current
 * process. Returns the user address, or NULL on failure.
 */
void * __cdecl
hxen_map_pages(unsigned int num, xen_pfn_t *mfns)
{
    void *user_va = mmap_pages(num, mfns);

    return user_va;
}

/*
 * Exported entry point: unmap `num` pages previously mapped at user address
 * va. Returns 0 on success or EINVAL.
 */
int __cdecl
hxen_unmap_pages(unsigned int num, void *va)
{
    uint8_t *base = va;

    return munmap_pages(num, base);
}
