/*++
Copyright (c) Citrix Corporation.  All rights reserved.

Module Name:

    driver.c

Abstract:

    Driver interface routines: Win32 routines to dynamically load and
    unload a Windows NT kernel-mode driver using the Service Control
    Manager APIs, and to open and send IOCTLs to the driver's device.

    This code is borrowed from the MS DDK samples.

Environment:

    User mode only

--*/


#include <windows.h>
#include <winnt.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "tenkei2.h"

/*
 * printf-style logging routine; it is not defined in this file.  The
 * format attribute lets GCC check the arguments of every call.
 */
void term_printf(const char *fmt, ...) __attribute__ ((__format__ (__printf__, 1, 2)));

//
// Forward declarations of the service-control helpers defined later in
// this file.  ManageDriver calls them before their definitions appear.
//

BOOLEAN
RemoveDriver(
    SC_HANDLE  SchSCManager,
    LPCTSTR    DriverName
    );

BOOLEAN
StartDriver(
    SC_HANDLE  SchSCManager,
    LPCTSTR    DriverName
    );

BOOLEAN
StopDriver(
    SC_HANDLE  SchSCManager,
    LPCTSTR    DriverName
    );

BOOLEAN
InstallDriver(
    SC_HANDLE  SchSCManager,
    LPCTSTR    DriverName,
    LPCTSTR    ServiceExe
    )
/*++

Routine Description:

    Registers DriverName with the Service Control Manager as a
    demand-start kernel driver whose image is ServiceExe.  The service
    is only created here; this routine does not start it.

Arguments:

    SchSCManager - handle to an open SCM database; it must allow
                   service creation.

    DriverName   - service name, also used as the display name.

    ServiceExe   - path of the driver image file.

Return Value:

    TRUE if the service was created or already exists, FALSE otherwise
    (the Win32 error is printed).

--*/
{
    SC_HANDLE   service;
    DWORD       lastError;

    //
    // NOTE: This creates an entry for a standalone driver. If this
    //       is modified for use with a driver that requires a Tag,
    //       Group, and/or Dependencies, it may be necessary to
    //       query the registry for existing driver information
    //       (in order to determine a unique Tag, etc.).
    //

    service = CreateService(SchSCManager,           // SCM database
                            DriverName,             // service name
                            DriverName,             // display name
                            SERVICE_ALL_ACCESS,     // desired access
                            SERVICE_KERNEL_DRIVER,  // service type
                            SERVICE_DEMAND_START,   // start type
                            SERVICE_ERROR_NORMAL,   // error control
                            ServiceExe,             // binary path
                            NULL,                   // no load-order group
                            NULL,                   // no tag
                            NULL,                   // no dependencies
                            NULL,                   // LocalSystem account
                            NULL                    // no password
                            );

    if (service != NULL) {

        //
        // Created successfully; the handle itself is not needed.
        //

        CloseServiceHandle(service);
        return TRUE;
    }

    lastError = GetLastError();

    //
    // A service that is already registered counts as success.
    //

    if (lastError == ERROR_SERVICE_EXISTS) {
        return TRUE;
    }

    printf("CreateService failed!  Error = %ld \n", lastError);
    return FALSE;

}   // InstallDriver

BOOLEAN
ManageDriver(
    LPCTSTR  DriverName,
    LPCTSTR  FullPath,
    USHORT   Function
    )
/*++

Routine Description:

    Performs a driver-service operation through the Service Control
    Manager.  DRIVER_FUNC_INSTALL creates the driver service (if needed)
    and starts it.  DRIVER_FUNC_REMOVE stops the driver and deletes the
    service.

Arguments:

    DriverName - service name of the driver.  Must not be NULL.

    FullPath   - path of the driver image.  Required (non-NULL) for
                 DRIVER_FUNC_INSTALL; not used by DRIVER_FUNC_REMOVE.

    Function   - DRIVER_FUNC_INSTALL or DRIVER_FUNC_REMOVE.

Return Value:

    TRUE on success.  DRIVER_FUNC_REMOVE always returns TRUE because
    stop/remove errors are only printed.  FALSE for a NULL argument,
    failure to open the SCM, install/start failure, or an unknown
    Function.

--*/
{
    SC_HANDLE   schSCManager;
    BOOLEAN     rCode = TRUE;

    if (!DriverName) {
        printf("Invalid Driver (NULL) provided to ManageDriver()\n");
        return FALSE;
    }

    //
    // Connect to the Service Control Manager and open the Services database.
    //

    schSCManager = OpenSCManager(NULL,                   // local machine
                                 NULL,                   // local database
                                 SC_MANAGER_ALL_ACCESS   // access required
                                 );

    if (!schSCManager) {
        printf("Open SC Manager failed! Error = %ld \n", GetLastError());
        return FALSE;
    }

    //
    // Do the requested function.  Every case leaves through 'break'
    // (never 'return'), so the SCM handle below is always closed.
    //

    switch( Function ) {

        case DRIVER_FUNC_INSTALL:

            if (!FullPath) {
                printf("Invalid driver path (NULL) provided to ManageDriver()\n");
                rCode = FALSE;
                break;
            }

            //
            // Install the driver service, then start it.
            //

            if (InstallDriver(schSCManager, DriverName, FullPath)) {
                rCode = StartDriver(schSCManager, DriverName);
            } else {
                rCode = FALSE;
            }
            break;

        case DRIVER_FUNC_REMOVE:

            //
            // Stop the driver, then delete the service.  Both steps are
            // best effort: their results are deliberately ignored.
            //

            StopDriver(schSCManager, DriverName);
            RemoveDriver(schSCManager, DriverName);

            rCode = TRUE;
            break;

        default:

            printf("Unknown ManageDriver() function. \n");

            rCode = FALSE;
            break;
    }

    //
    // Close handle to service control manager.
    //

    CloseServiceHandle(schSCManager);

    return rCode;

}   // ManageDriver


BOOLEAN
RemoveDriver(
    SC_HANDLE    SchSCManager,
    LPCTSTR      DriverName
    )
/*++

Routine Description:

    Marks the DriverName service for deletion from the Service Control
    Manager database.  This routine does not stop the driver itself.

Arguments:

    SchSCManager - handle to an open SCM database.

    DriverName   - service name of the driver.

Return Value:

    TRUE if DeleteService succeeded.  FALSE if the service could not be
    opened or deleted (the Win32 error is printed).

--*/
{
    SC_HANDLE   service;
    BOOLEAN     deleted;

    service = OpenService(SchSCManager, DriverName, SERVICE_ALL_ACCESS);

    if (service == NULL) {
        printf("OpenService failed!  Error = %ld \n", GetLastError());
        return FALSE;
    }

    deleted = DeleteService(service) ? TRUE : FALSE;

    if (!deleted) {
        printf("DeleteService failed!  Error = %ld \n", GetLastError());
    }

    //
    // Release the service handle whether or not the delete worked.
    //

    CloseServiceHandle(service);

    return deleted;

}   // RemoveDriver



BOOLEAN
StartDriver(
    SC_HANDLE    SchSCManager,
    LPCTSTR      DriverName
    )
/*++

Routine Description:

    Starts the DriverName service (i.e. starts the driver).

Arguments:

    SchSCManager - handle to an open SCM database.

    DriverName   - service name of the driver.

Return Value:

    TRUE if the service was started or was already running.  FALSE if
    it could not be opened or started (the Win32 error is printed).

--*/
{
    SC_HANDLE   schService;
    BOOLEAN     rCode = TRUE;
    DWORD       err;

    //
    // Open the handle to the existing service.
    //

    schService = OpenService(SchSCManager,
                             DriverName,
                             SERVICE_ALL_ACCESS
                             );

    if (schService == NULL) {

        printf("OpenService failed!  Error = %ld \n", GetLastError());
        return FALSE;
    }

    //
    // Start the execution of the service (i.e. start the driver).
    //

    if (!StartService(schService,     // service identifier
                      0,              // number of arguments
                      NULL            // pointer to arguments
                      )) {

        err = GetLastError();

        //
        // A driver that is already running counts as success.
        //

        if (err != ERROR_SERVICE_ALREADY_RUNNING) {
            printf("StartService failure! Error = %ld\n", err);
            rCode = FALSE;
        }
    }

    //
    // Close the service object on every path, including StartService
    // failures, so the handle is never leaked.
    //

    CloseServiceHandle(schService);

    return rCode;

}   // StartDriver



BOOLEAN
StopDriver(
    SC_HANDLE    SchSCManager,
    LPCTSTR      DriverName
    )
/*++

Routine Description:

    Asks the DriverName service to stop by sending SERVICE_CONTROL_STOP.
    This routine does not wait for the service to reach the stopped
    state.

Arguments:

    SchSCManager - handle to an open SCM database.

    DriverName   - service name of the driver.

Return Value:

    TRUE if ControlService accepted the stop request.  FALSE if the
    service could not be opened or the request failed (the Win32 error
    is printed).

--*/
{
    SERVICE_STATUS  lastStatus;
    SC_HANDLE       service;
    BOOLEAN         stopRequested;

    service = OpenService(SchSCManager, DriverName, SERVICE_ALL_ACCESS);

    if (service == NULL) {
        printf("OpenService failed!  Error = %ld\n", GetLastError());
        return FALSE;
    }

    //
    // lastStatus receives the service status from ControlService; only
    // the success or failure of the request is used here.
    //

    stopRequested = ControlService(service, SERVICE_CONTROL_STOP, &lastStatus)
                    ? TRUE : FALSE;

    if (!stopRequested) {
        printf("ControlService failed!  Error = %ld\n", GetLastError());
    }

    CloseServiceHandle(service);

    return stopRequested;

}   //  StopDriver


BOOLEAN
SetupDriverName(
    PCHAR DriverLocation,
    ULONG BufferLength
    )
/*++

Routine Description:

    Builds "<current directory>\<DRIVER_NAME>.sys" in DriverLocation and
    checks that the file can be opened for reading.

Arguments:

    DriverLocation - caller buffer that receives the full path.

    BufferLength   - size of DriverLocation in characters, including
                     room for the terminating NUL.

Return Value:

    TRUE if the path fits in the buffer and the file can be opened.
    FALSE otherwise.

--*/
{
    static const CHAR driverFile[] = "\\" DRIVER_NAME ".sys";
    HANDLE fileHandle;
    DWORD  driverLocLen;

    //
    // Get the current directory.  On success the return value is the
    // length without the NUL.  If the buffer is too small, it is the
    // required size (>= BufferLength) and the buffer is not filled, so
    // that case must be rejected too.
    //

    driverLocLen = GetCurrentDirectory(BufferLength,
                                       DriverLocation
                                       );

    if (driverLocLen == 0) {

        printf("GetCurrentDirectory failed!  Error = %ld\n", GetLastError());

        return FALSE;
    }

    //
    // Append the driver file name, making sure it fits together with
    // its NUL.  (strncat's count only limits the characters appended,
    // not the destination size, so it cannot serve as a bounds check.)
    //

    if (driverLocLen >= BufferLength ||
        BufferLength - driverLocLen < sizeof(driverFile)) {

        printf("Path to %s.sys does not fit in the supplied buffer.\n", DRIVER_NAME);

        return FALSE;
    }

    memcpy(DriverLocation + driverLocLen, driverFile, sizeof(driverFile));

    //
    // Ensure the driver file is in the specified directory.
    //

    fileHandle = CreateFile(DriverLocation,
                            GENERIC_READ,
                            0,
                            NULL,
                            OPEN_EXISTING,
                            FILE_ATTRIBUTE_NORMAL,
                            NULL
                            );

    if (fileHandle == INVALID_HANDLE_VALUE) {

        printf("%s.sys is not loaded.\n", DRIVER_NAME);

        return FALSE;
    }

    CloseHandle(fileHandle);

    return TRUE;

}   // SetupDriverName



HANDLE
InitDriver(void)
/*++

Routine Description:

    Opens the DRIVER_NAME device.  If the device does not exist yet
    (ERROR_FILE_NOT_FOUND), this routine does the following in order:
      - locates <DRIVER_NAME>.sys in the current directory;
      - calls LoadBlob;
      - installs and starts the driver service;
      - opens the device;
      - passes the blob to InstallBlob.
    On success it calls timeBeginPeriod(1) to request 1 ms timer
    resolution.

Return Value:

    Device handle (opened with FILE_FLAG_OVERLAPPED), or
    INVALID_HANDLE_VALUE on failure.  Any failure after the driver
    service was installed removes the service again.

--*/
{
    HANDLE  hDevice;
    CHAR    driverLocation[MAX_PATH];   // CHAR to match SetupDriverName/ManageDriver
    DWORD   errNum;
    MEMDESC blob;

    //
    // Try to connect to driver.  If this fails, try to load the driver
    // dynamically.
    //

    if ((hDevice = CreateFile("\\\\.\\" DRIVER_NAME,
                                 GENERIC_READ,
                                 0,
                                 NULL,
                                 OPEN_EXISTING,
                                 FILE_FLAG_OVERLAPPED,
                                 NULL
                                 )) == INVALID_HANDLE_VALUE) {

        errNum = GetLastError();

        if (errNum != ERROR_FILE_NOT_FOUND) {
            printf("CreateFile failed!  Error = %ld\n", errNum);
            return  INVALID_HANDLE_VALUE;
        }

        //
        // Setup full path to driver name.
        //

        if (!SetupDriverName(driverLocation, sizeof(driverLocation))) {
            return INVALID_HANDLE_VALUE;
        }

        //
        // Before loading the driver, make sure we can also load the
        // Blob.
        // NOTE(review): if LoadBlob acquires resources for 'blob', none
        // of the paths below release them -- confirm against LoadBlob.
        //

        if (!LoadBlob(&blob)) {
            return INVALID_HANDLE_VALUE;
        }

        //
        // Install driver.
        //

        if (!ManageDriver(DRIVER_NAME,
                          driverLocation,
                          DRIVER_FUNC_INSTALL
                          )) {

            printf("Unable to install driver. \n");

            //
            // Error - remove driver.
            //

            ManageDriver(DRIVER_NAME,
                         driverLocation,
                         DRIVER_FUNC_REMOVE
                         );
            return INVALID_HANDLE_VALUE;
        }

        //
        // Try to open the newly installed driver.
        //

        hDevice = CreateFile( "\\\\.\\" DRIVER_NAME,
                GENERIC_READ,
                0,
                NULL,
                OPEN_EXISTING,
                FILE_FLAG_OVERLAPPED,
                NULL);

        if ( hDevice == INVALID_HANDLE_VALUE ){
            printf("Error: CreateFile Failed : %ld\n", GetLastError());

            //
            // Unload the driver that was just installed.  If it stayed
            // loaded, a later call could open the device, take the
            // "existing driver" path, and never install the blob.
            //

            ManageDriver(DRIVER_NAME,
                         driverLocation,
                         DRIVER_FUNC_REMOVE
                         );
            return INVALID_HANDLE_VALUE;
        }

        //
        // Freshly minted driver, install the blob.
        //

        if (!InstallBlob(hDevice, &blob)) {
            //
            // Failed to install blob?  Nothing's going to work.
            //

            CloseHandle(hDevice);
            ManageDriver(DRIVER_NAME,
                         driverLocation,
                         DRIVER_FUNC_REMOVE
                         );
            printf("Error installing blob, closing/unloading driver.\n");
            return INVALID_HANDLE_VALUE;
        }
        printf("Installed/opened new instance of driver.\n");
    } else {
        printf("Opened existing driver.\n");
    }

    //
    // Request 1 ms resolution for the system timer.
    //

    timeBeginPeriod(1);
    return hDevice;
}


VOID
ExitDriver(HANDLE hDevice)
/*++

Routine Description:

    Closes hDevice (presumably the handle returned by InitDriver) and
    logs the caller's return address and the outcome via term_printf.

    NOTE(review): InitDriver calls timeBeginPeriod(1), but no matching
    timeEndPeriod(1) is made here.  Confirm whether that is intended.

--*/
{
    //
    // Only level 0 is queried.  GCC documents nonzero levels of
    // __builtin_return_address as unreliable: they walk the frame-pointer
    // chain and can crash when a caller was built without frame pointers.
    //

    term_printf("CloseHandle from %p\n", __builtin_return_address(0));

    if (!CloseHandle(hDevice)) {
        term_printf("CloseHandle failed %ld\n", GetLastError());
    } else {
        term_printf("CloseHandle success\n");
    }
}


BOOLEAN
Ioctl(
    HANDLE      hDevice,
    ULONG       ControlCode,
    PVOID       InBuffer,
    ULONG       InBufferLength,
    PVOID       OutBuffer,
    ULONG       OutBufferLength,
    PULONG      BytesReturned
    )
/*++

Routine Description:

    Sends ControlCode to hDevice and waits for the request to complete.

    InitDriver opens the device with FILE_FLAG_OVERLAPPED.  For such
    handles DeviceIoControl requires a valid OVERLAPPED structure; with
    a NULL lpOverlapped the call is documented to fail unpredictably,
    and a pending request could use the buffers after this routine
    returns.  This routine therefore supplies an OVERLAPPED with its own
    event and waits for a pending request with GetOverlappedResult.  For
    a handle opened without FILE_FLAG_OVERLAPPED, the OVERLAPPED is
    ignored and the call is synchronous.

Arguments:

    hDevice         - device handle.

    ControlCode     - IOCTL code to send.

    InBuffer        - input buffer.
    InBufferLength  - size of InBuffer in bytes.

    OutBuffer       - output buffer.
    OutBufferLength - size of OutBuffer in bytes.

    BytesReturned   - if not NULL, receives the byte count reported for
                      the request.

Return Value:

    TRUE on success.  FALSE on failure (the Win32 error is printed to
    stderr).

--*/
{
    OVERLAPPED  overlapped;
    DWORD       transferred = 0;
    DWORD       err;
    BOOL        ok;

    ZeroMemory(&overlapped, sizeof(overlapped));

    //
    // Manual-reset, initially non-signaled event for the request.
    //

    overlapped.hEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
    if (overlapped.hEvent == NULL) {
        fprintf(stderr, "DevIOCTL(%lx) failed %ld\n", ControlCode,
                GetLastError());
        return FALSE;
    }

    ok = DeviceIoControl(hDevice,
                         ControlCode,
                         InBuffer,
                         InBufferLength,
                         OutBuffer,
                         OutBufferLength,
                         &transferred,
                         &overlapped);

    if (!ok && GetLastError() == ERROR_IO_PENDING) {
        //
        // The driver queued the request: block until it completes.
        //
        ok = GetOverlappedResult(hDevice, &overlapped, &transferred, TRUE);
    }

    //
    // Capture the error before CloseHandle can overwrite it.
    //

    err = ok ? ERROR_SUCCESS : GetLastError();
    CloseHandle(overlapped.hEvent);

    if (BytesReturned != NULL) {
        *BytesReturned = transferred;
    }

    if (!ok) {
        fprintf(stderr, "DevIOCTL(%lx) failed %ld\n", ControlCode, err);
        return FALSE;
    }

    //
    // Normalize BOOL to BOOLEAN explicitly rather than truncating by cast.
    //

    return TRUE;
}

