/*++
Copyright (c) Citrix Corporation.  All rights reserved.

Module Name:

    patchxp.c

Abstract:

    Load/Unload and run the Xen XP patch driver.

Environment:

    User mode only

--*/


#include <windows.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strsafe.h>

//
// Function codes understood by ManageDriver().
//

#define DRIVER_FUNC_INSTALL     0x01
#define DRIVER_FUNC_REMOVE      0x02

//
// Service name of the patch driver. The driver image is looked up as
// DRIVER_NAME ".sys" in the current working directory.
//

#define DRIVER_NAME       "XenPatch"

//
// Forward declarations for routines that are referenced before their
// definitions appear below.
//

HANDLE
OpenDriver(
    void
    );

VOID
CloseDriver(
    HANDLE hDevice
    );

BOOLEAN
StartDriver(
    __in SC_HANDLE  SchSCManager,
    __in LPCTSTR    DriverName
    );

BOOLEAN
StopDriver(
    __in SC_HANDLE  SchSCManager,
    __in LPCTSTR    DriverName
    );

BOOLEAN
RemoveDriver(
    __in SC_HANDLE  SchSCManager,
    __in LPCTSTR    DriverName
    );

BOOLEAN
InstallDriver(
    __in SC_HANDLE  SchSCManager,
    __in LPCTSTR    DriverName,
    __in LPCTSTR    ServiceExe
    )
/*++

Routine Description:

    Registers DriverName with the service control manager as a
    demand-start kernel-driver service whose image is ServiceExe.
    A service that is already registered under the same name is
    treated as success (its existing configuration is left untouched).

Arguments:

    SchSCManager - Handle to an open service control manager database.

    DriverName - Service name; also used as the display name.

    ServiceExe - Path to the driver binary.

Return Value:

    TRUE if the service was created or already existed, FALSE otherwise.

--*/
{
    SC_HANDLE   schService;
    DWORD       err;

    //
    // NOTE: This creates an entry for a standalone driver. If this
    //       is modified for use with a driver that requires a Tag,
    //       Group, and/or Dependencies, it may be necessary to
    //       query the registry for existing driver information
    //       (in order to determine a unique Tag, etc.).
    //

    schService = CreateService(SchSCManager,           // handle of service control manager database
                               DriverName,             // address of name of service to start
                               DriverName,             // address of display name
                               SERVICE_ALL_ACCESS,     // type of access to service
                               SERVICE_KERNEL_DRIVER,  // type of service
                               SERVICE_DEMAND_START,   // when to start service
                               SERVICE_ERROR_NORMAL,   // severity if service fails to start
                               ServiceExe,             // address of name of binary file
                               NULL,                   // service does not belong to a group
                               NULL,                   // no tag requested
                               NULL,                   // no dependency names
                               NULL,                   // use LocalSystem account
                               NULL                    // no password for service account
                               );

    if (schService == NULL) {

        err = GetLastError();

        if (err == ERROR_SERVICE_EXISTS) {

            //
            // Already registered (e.g. by a previous run); reuse it.
            //

            return TRUE;
        }

        //
        // DWORD is unsigned long: print it with %lu, not %d.
        //

        printf("CreateService failed!  Error = %lu \n", err);
        return FALSE;
    }

    //
    // Only the registration was needed; release the service handle.
    //

    CloseServiceHandle(schService);

    return TRUE;

}   // InstallDriver

BOOLEAN
ManageDriver(
    __in LPCTSTR  DriverName,
    __in LPCTSTR  FullPath,
    __in USHORT   Function
    )
/*++

Routine Description:

    Opens the local service control manager and performs the requested
    operation on the driver service DriverName.

Arguments:

    DriverName - Name of the driver service. Must not be NULL.

    FullPath - Path to the driver binary. Required (non-NULL) for
               DRIVER_FUNC_INSTALL; not used by DRIVER_FUNC_REMOVE.

    Function - DRIVER_FUNC_INSTALL: register (if needed) and start the driver.
               DRIVER_FUNC_REMOVE: stop and unregister the driver.

Return Value:

    DRIVER_FUNC_INSTALL: TRUE if the service was registered (or already
    existed) and started (or was already running), FALSE otherwise.

    DRIVER_FUNC_REMOVE: always TRUE. Failures are printed but ignored.

    FALSE for a NULL DriverName, a NULL FullPath on install, an unknown
    Function, or failure to open the service control manager.

--*/
{
    SC_HANDLE   schSCManager;
    BOOLEAN     rCode = TRUE;

    if (!DriverName) {
        printf("Invalid Driver (NULL) provided to ManageDriver()\n");
        return FALSE;
    }

    //
    // Connect to the Service Control Manager and open the Services database.
    //

    schSCManager = OpenSCManager(NULL,                   // local machine
                                 NULL,                   // local database
                                 SC_MANAGER_ALL_ACCESS   // access required
                                 );

    if (!schSCManager) {
        printf("Open SC Manager failed! Error = %lu \n", GetLastError());
        return FALSE;
    }

    //
    // From here on every path must reach the CloseServiceHandle below, so
    // the switch only sets rCode and breaks; it never returns directly.
    //

    switch (Function) {

        case DRIVER_FUNC_INSTALL:

            if (!FullPath) {
                printf("Invalid driver path (NULL) provided to ManageDriver()\n");
                rCode = FALSE;
                break;
            }

            //
            // Register the driver service, then start it.
            //

            if (InstallDriver(schSCManager, DriverName, FullPath)) {
                rCode = StartDriver(schSCManager, DriverName);
            } else {
                rCode = FALSE;
            }
            break;

        case DRIVER_FUNC_REMOVE:

            //
            // Best-effort teardown: stop the driver, then mark the service
            // for deletion. Both helpers print their own errors, and their
            // results are deliberately ignored.
            //

            StopDriver(schSCManager, DriverName);
            RemoveDriver(schSCManager, DriverName);

            rCode = TRUE;
            break;

        default:

            printf("Unknown ManageDriver() function. \n");

            rCode = FALSE;
            break;
    }

    //
    // Close handle to service control manager.
    //

    CloseServiceHandle(schSCManager);

    return rCode;

}   // ManageDriver


BOOLEAN
RemoveDriver(
    __in SC_HANDLE    SchSCManager,
    __in LPCTSTR      DriverName
    )
/*++

Routine Description:

    Marks the driver service DriverName for deletion from the service
    control manager database.

Arguments:

    SchSCManager - Handle to an open service control manager database.

    DriverName - Name of the service to delete.

Return Value:

    TRUE if DeleteService succeeded, FALSE if the service could not be
    opened or deleted. Errors are printed to stdout.

--*/
{
    SC_HANDLE   schService;
    BOOLEAN     rCode;

    //
    // Open the handle to the existing service.
    //

    schService = OpenService(SchSCManager,
                             DriverName,
                             SERVICE_ALL_ACCESS
                             );

    if (schService == NULL) {
        printf("OpenService failed!  Error = %lu \n", GetLastError());
        return FALSE;
    }

    //
    // Mark the service for deletion from the service control manager database.
    //

    if (DeleteService(schService)) {
        rCode = TRUE;
    } else {
        printf("DeleteService failed!  Error = %lu \n", GetLastError());
        rCode = FALSE;
    }

    //
    // schService is known to be valid here; always release it.
    //

    CloseServiceHandle(schService);

    return rCode;

}   // RemoveDriver



BOOLEAN
StartDriver(
    __in SC_HANDLE    SchSCManager,
    __in LPCTSTR      DriverName
    )
/*++

Routine Description:

    Starts the driver service DriverName. A service that is already
    running is treated as success.

Arguments:

    SchSCManager - Handle to an open service control manager database.

    DriverName - Name of the service to start.

Return Value:

    TRUE if the service was started or was already running, FALSE if it
    could not be opened or started. Errors are printed to stdout.

--*/
{
    SC_HANDLE   schService;
    DWORD       err;
    BOOLEAN     rCode = TRUE;

    //
    // Open the handle to the existing service.
    //

    schService = OpenService(SchSCManager,
                             DriverName,
                             SERVICE_ALL_ACCESS
                             );

    if (schService == NULL) {
        printf("OpenService failed!  Error = %lu \n", GetLastError());
        return FALSE;
    }

    //
    // Start the execution of the service (i.e. start the driver).
    //

    if (!StartService(schService,     // service identifier
                      0,              // number of arguments
                      NULL            // pointer to arguments
                      )) {

        err = GetLastError();

        //
        // ERROR_SERVICE_ALREADY_RUNNING is not a failure for our purposes.
        // In every case, fall through so the service handle is closed.
        // The original code returned here and leaked schService.
        //

        if (err != ERROR_SERVICE_ALREADY_RUNNING) {
            printf("StartService failure! Error = %lu \n", err);
            rCode = FALSE;
        }
    }

    //
    // Close the service object on every path past OpenService.
    //

    CloseServiceHandle(schService);

    return rCode;

}   // StartDriver



BOOLEAN
StopDriver(
    __in SC_HANDLE    SchSCManager,
    __in LPCTSTR      DriverName
    )
/*++

Routine Description:

    Sends SERVICE_CONTROL_STOP to the driver service DriverName. The
    routine does not poll for the service to reach the stopped state.

Arguments:

    SchSCManager - Handle to an open service control manager database.

    DriverName - Name of the service to stop.

Return Value:

    TRUE if ControlService accepted the stop request, FALSE if the service
    could not be opened or the control failed. Errors are printed to stdout.

--*/
{
    BOOLEAN         rCode = TRUE;
    SC_HANDLE       schService;
    SERVICE_STATUS  serviceStatus;

    //
    // Open the handle to the existing service.
    //

    schService = OpenService(SchSCManager,
                             DriverName,
                             SERVICE_ALL_ACCESS
                             );

    if (schService == NULL) {
        printf("OpenService failed!  Error = %lu \n", GetLastError());
        return FALSE;
    }

    //
    // Request that the service stop. serviceStatus receives the latest
    // status but is not otherwise examined.
    //

    if (ControlService(schService,
                       SERVICE_CONTROL_STOP,
                       &serviceStatus
                       )) {
        rCode = TRUE;
    } else {
        printf("ControlService failed!  Error = %lu \n", GetLastError());
        rCode = FALSE;
    }

    //
    // schService is known to be valid here; always release it.
    //

    CloseServiceHandle(schService);

    return rCode;

}   //  StopDriver


BOOLEAN
SetupDriverName(
    __inout_bcount_full(BufferLength) PCHAR DriverLocation,
    __in ULONG BufferLength
    )
/*++

Routine Description:

    Builds "<current directory>\" DRIVER_NAME ".sys" in DriverLocation and
    verifies that the file can be opened for reading.

    NOTE(review): BufferLength is a byte count, but it is passed to
    GetCurrentDirectory, which expects a character count. The two agree
    only for the narrow (ANSI) build this routine is written for (a PCHAR
    buffer and a narrow string literal).

Arguments:

    DriverLocation - Buffer that receives the full driver path.

    BufferLength - Size of DriverLocation in bytes.

Return Value:

    TRUE if the path was built and the file could be opened, FALSE
    otherwise.

--*/
{
    HANDLE fileHandle;
    DWORD driverLocLen;

    //
    // Get the current directory.
    //

    driverLocLen = GetCurrentDirectory(BufferLength,
                                       DriverLocation
                                       );

    if (driverLocLen == 0) {

        printf("GetCurrentDirectory failed!  Error = %lu \n", GetLastError());

        return FALSE;
    }

    //
    // When the buffer is too small, GetCurrentDirectory returns the size it
    // would need (including the terminator) and does not fill the buffer.
    // A successful call always returns less than BufferLength.
    //

    if (driverLocLen >= BufferLength) {

        printf("GetCurrentDirectory failed!  Current directory path is too long.\n");

        return FALSE;
    }

    //
    // Set up the path name to the driver file. StringCbCat fails rather than
    // truncating if the result would not fit.
    //

    if (FAILED( StringCbCat(DriverLocation, BufferLength, "\\"DRIVER_NAME".sys") )) {
        return FALSE;
    }

    //
    // Ensure the driver file is present in the specified directory.
    //

    fileHandle = CreateFile(DriverLocation,
                            GENERIC_READ,
                            0,
                            NULL,
                            OPEN_EXISTING,
                            FILE_ATTRIBUTE_NORMAL,
                            NULL
                            );

    if (fileHandle == INVALID_HANDLE_VALUE) {

        printf("%s.sys is not loaded.\n", DRIVER_NAME);

        return FALSE;
    }

    //
    // The handle was only needed to probe for the file.
    //

    CloseHandle(fileHandle);

    return TRUE;

}   // SetupDriverName


HANDLE
OpenDriver(
    void
    )
/*++

Routine Description:

    Builds the path to DRIVER_NAME ".sys" in the current directory, then
    installs and starts it as a kernel-driver service.

    If installing or starting fails, ManageDriver(DRIVER_FUNC_REMOVE) is
    called to stop and unregister the service, and its errors are ignored.
    This also removes a service of the same name that already existed
    before this call.

Arguments:

    None.

Return Value:

    Currently always INVALID_HANDLE_VALUE. The driver has no create
    support, so no device handle is opened even when installation
    succeeds. The return value therefore cannot distinguish success
    from failure.

--*/
{
    //
    // CHAR (not UCHAR) so the buffer matches SetupDriverName's PCHAR and
    // ManageDriver's LPCTSTR parameters without a pointer-sign mismatch.
    //

    CHAR    driverLocation[MAX_PATH];

    //
    // Set up the full path to the driver image.
    //

    if (!SetupDriverName(driverLocation, sizeof(driverLocation))) {
        return INVALID_HANDLE_VALUE;
    }

    //
    // Install and start the driver.
    //

    if (!ManageDriver(DRIVER_NAME,
                      driverLocation,
                      DRIVER_FUNC_INSTALL
                      )) {

        printf("Unable to install driver. \n");

        //
        // Error - remove driver.
        //

        ManageDriver(DRIVER_NAME,
                     driverLocation,
                     DRIVER_FUNC_REMOVE
                     );
        return INVALID_HANDLE_VALUE;
    }

#if 0 // PLJ, no create support in driver.
    {
        HANDLE  hDevice;

        //
        // Try to open the newly installed driver.
        //

        hDevice = CreateFile( "\\\\.\\" DRIVER_NAME,
                GENERIC_READ,
                0,
                NULL,
                OPEN_EXISTING,
                FILE_FLAG_OVERLAPPED,
                NULL);

        if ( hDevice == INVALID_HANDLE_VALUE ){
            printf("Error: CreateFile Failed : %lu\n", GetLastError());
            return INVALID_HANDLE_VALUE;
        }
        return hDevice;
    }
#else
    return INVALID_HANDLE_VALUE;
#endif
}


VOID
CloseDriver(HANDLE hDevice)
/*++

Routine Description:

    Closes a device handle obtained from OpenDriver. The call does nothing
    for NULL or INVALID_HANDLE_VALUE, which OpenDriver currently always
    returns. CloseHandle must not be passed an invalid or pseudo handle
    (it raises an exception under a debugger).

Arguments:

    hDevice - Handle to close. May be NULL or INVALID_HANDLE_VALUE.

Return Value:

    None.

--*/
{
    if (hDevice == NULL || hDevice == INVALID_HANDLE_VALUE) {
        return;
    }

    CloseHandle(hDevice);
}


//
// Main function
//

VOID __cdecl
main(
    __in ULONG argc,
    __in_ecount(argc) PCHAR argv[]
    )
/*++

Routine Description:

    Program entry point. Installs and starts the patch driver via
    OpenDriver, then terminates the process with exit code 0 whatever
    the outcome. Command-line arguments are ignored.

    The driver is deliberately left loaded and no device handle is
    closed. OpenDriver never opens one, and the patches it applies
    branch into driver code, so unloading the driver would be unsafe
    unless that code is moved into allocated memory.

--*/
{
    UNREFERENCED_PARAMETER(argc);
    UNREFERENCED_PARAMETER(argv);

    //
    // The returned handle is always INVALID_HANDLE_VALUE at present, so
    // there is nothing to check or close.
    //

    (VOID)OpenDriver();

    ExitProcess(0);
}
