#!/usr/bin/env python
#
# This is a save process which also buffers outgoing I/O between
# rounds, so that external viewers never see anything that hasn't
# been committed at the backup
#
# TODO: fencing.

import optparse, os, re, select, signal, sys, time
from xen.remus import save, vm
from xen.xend import XendOptions
from xen.remus import netlink, qdisc, util

class CfgException(Exception):
    """Raised by Cfg.getargs() when the command line names no domain."""

class Cfg(object):
    """Run-time settings: built-in defaults overridden by the command line.

    Attributes:
      domid     -- domain to protect, exactly as typed on the command line
      host      -- destination host for the checkpoint stream
      port      -- destination port (default: xend's relocation port)
      interval  -- checkpoint interval in milliseconds
      netbuffer -- False when --no-net was given
      nobackup  -- True when --no-backup was given
      timer     -- True when --timer was given
      parser    -- the optparse.OptionParser used by getargs() and usage()
    """

    def __init__(self):
        """Set the defaults and build the option parser (nothing is parsed yet)."""
        # must be set
        self.domid = 0

        self.host = 'localhost'
        self.port = XendOptions.instance().get_xend_relocation_port()
        self.interval = 200
        self.netbuffer = True
        self.nobackup = False
        self.timer = False

        parser = optparse.OptionParser()
        parser.usage = '%prog [options] domain [destination]'
        parser.add_option('-i', '--interval', dest='interval', type='int',
                          metavar='MS',
                          help='checkpoint every MS milliseconds')
        parser.add_option('-p', '--port', dest='port', type='int',
                          help='send stream to port PORT', metavar='PORT')
        parser.add_option('--no-net', dest='nonet', action='store_true',
                          help='run without net buffering (benchmark option)')
        parser.add_option('--timer', dest='timer', action='store_true',
                          help='force pause at checkpoint interval (experimental)')
        parser.add_option('--no-backup', dest='nobackup',
                          action='store_true',
                          help='prevent backup from starting up (benchmark '
                          'option)')
        self.parser = parser

    def usage(self):
        """Print the option summary generated by the parser."""
        self.parser.print_help()

    def getargs(self):
        """Parse sys.argv and overwrite the defaults with what was supplied.

        The first positional argument is the domain and the optional
        second one is the destination host.

        Raises CfgException when no domain is given.
        """
        opts, args = self.parser.parse_args()

        # Options that were not supplied come back as None and leave the
        # default untouched.
        if opts.interval:
            self.interval = opts.interval
        if opts.port:
            self.port = opts.port
        if opts.nonet:
            self.netbuffer = False
        if opts.timer:
            self.timer = True
        # --no-backup used to be accepted by the parser but never copied
        # here, so self.nobackup could never become True.
        if opts.nobackup:
            self.nobackup = True

        if not args:
            raise CfgException('Missing domain')
        self.domid = args[0]
        if len(args) > 1:
            self.host = args[1]

class ReplicatedDiskException(Exception):
    """Raised by ReplicatedDisk when the given disk is not a replicated one."""

class BufferedDevice(object):
    """Base class for buffered devices.

    Each hook does nothing here; subclasses override the ones they need.
    """

    def postsuspend(self):
        """Hook called after the guest has suspended."""

    def preresume(self):
        """Hook called before the guest resumes."""

    def commit(self):
        """Hook called when the backup has acknowledged checkpoint reception."""

class ReplicatedDisk(BufferedDevice):
    """
    Send a checkpoint message to a replicated disk while the domain
    is paused between epochs.

    The disk is driven through two files under FIFODIR: a control file
    that receives 'flush' requests, and a companion '.msg' file from
    which the 'done' acknowledgement is read.
    """
    FIFODIR = '/var/run/tap'

    def __init__(self, disk):
        """Open the control and message channels belonging to `disk`.

        disk -- object whose `uname` attribute names the backing device.
                Only unames starting with 'tap:remus:' or
                'tap:tapdisk:remus:' are accepted.

        Raises ReplicatedDiskException if the disk is not replicated, or
        if its uname lacks the '|' that terminates the remus component.
        """
        # Initialise everything uninstall() looks at first: __del__ calls
        # it even when construction fails part-way through.
        self.ctlfd = None
        self.msgfd = None
        self.installed = False

        uname = disk.uname
        if not (uname.startswith('tap:remus:') or
                uname.startswith('tap:tapdisk:remus:')):
            raise ReplicatedDiskException('Disk is not replicated: %s' %
                                        str(disk))

        # The file name is the 'remus...' part of the uname up to the
        # last '|', with ':' replaced by '_'.
        found = re.match(r"tap:.*(remus.*)\|", uname)
        if found is None:
            # Report this the same way as any other unusable disk instead
            # of failing with AttributeError on None.group().
            raise ReplicatedDiskException('Malformed replicated disk name: %s'
                                          % str(disk))
        fifo = found.group(1).replace(':', '_')
        absfifo = os.path.join(self.FIFODIR, fifo)
        absmsgfifo = absfifo + '.msg'

        self.ctlfd = open(absfifo, 'w+b')
        self.msgfd = open(absmsgfifo, 'r+b')

    def __del__(self):
        self.uninstall()

    def setup(self):
        """Mark the disk as set up; no message is currently sent."""
        #self.ctlfd.write('buffer')
        #self.ctlfd.flush()
        self.installed = True

    def uninstall(self):
        """Close both channels. Safe to call repeatedly."""
        if self.ctlfd is not None:
            self.ctlfd.close()
            self.ctlfd = None
        # The message channel used to be left open here.
        if self.msgfd is not None:
            self.msgfd.close()
            self.msgfd = None

    def postsuspend(self):
        """Ask the disk to flush, now that the guest is suspended."""
        if not self.installed:
            self.setup()

        # Unbuffered write straight to the descriptor.
        os.write(self.ctlfd.fileno(), 'flush')

    def commit(self):
        """Read the 4-byte reply from the disk; complain if it is not 'done'."""
        msg = os.read(self.msgfd.fileno(), 4)
        if msg != 'done':
            print('Unknown message: %s' % msg)

class NetbufferException(Exception):
    """Raised by Netbuffer when its device or queueing discipline is unusable."""

class Netbuffer(BufferedDevice):
    """
    Buffer a protected domain's network output between rounds so that
    nothing is issued that a failover might not know about.
    """
    # shared rtnetlink handle
    rth = None

    def __init__(self, domid):
        """Route the domain's vif through imq0 and look that device up.

        domid -- numeric domain id; formatted with %d to build the vif name.

        Raises NetbufferException if the device cannot be found.
        """
        self.installed = False

        # Assign on the class, not on self: binding through self would
        # create a per-instance attribute and every Netbuffer would open
        # its own handle instead of sharing one.
        if not Netbuffer.rth:
            Netbuffer.rth = netlink.rtnl()

        self.devname = self._startimq(domid)
        dev = self.rth.getlink(self.devname)
        if not dev:
            raise NetbufferException('could not find device %s' % self.devname)
        self.dev = dev['index']
        self.handle = qdisc.TC_H_ROOT
        self.q = qdisc.QueueQdisc()

    def __del__(self):
        self.uninstall()

    def postsuspend(self):
        """Install the queue on first use, then send it a CHECKPOINT message."""
        if not self.installed:
            self._setup()

        self._sendqmsg(qdisc.TC_QUEUE_CHECKPOINT)

    def commit(self):
        '''Called when checkpoint has been acknowledged by
        the backup'''
        self._sendqmsg(qdisc.TC_QUEUE_RELEASE)

    def _sendqmsg(self, action):
        """Send a change request carrying `action` to the queue on our device."""
        self.q.action = action
        req = qdisc.changerequest(self.dev, self.handle, self.q)
        self.rth.talk(req.pack())

    def _setup(self):
        """Install the queue qdisc on the device unless one is already there.

        Raises NetbufferException if the device already carries a qdisc
        other than 'queue' or 'pfifo_fast'.
        """
        q = self.rth.getqdisc(self.dev)
        if q:
            if q['kind'] == 'queue':
                # Already present: adopt it rather than adding another.
                self.installed = True
                return
            if q['kind'] != 'pfifo_fast':
                raise NetbufferException('there is already a queueing '
                                         'discipline on %s' % self.devname)

        print('installing buffer on %s' % self.devname)
        req = qdisc.addrequest(self.dev, self.handle, self.q)
        self.rth.talk(req.pack())
        self.installed = True

    def uninstall(self):
        """Delete the qdisc if it is marked installed. Safe to call repeatedly."""
        if self.installed:
            req = qdisc.delrequest(self.dev, self.handle)
            self.rth.talk(req.pack())
            self.installed = False

    def _startimq(self, domid):
        """Load the needed modules, redirect vif<domid>.0 to imq0, return 'imq0'."""
        # stopgap hack to set up IMQ for an interface. Wrong in many ways.
        imqebt = '/usr/lib/xen/bin/imqebt'
        imqdev = 'imq0'
        vid = 'vif%d.0' % domid
        for mod in ['sch_queue', 'imq', 'ebt_imq']:
            util.runcmd(['modprobe', mod])
        util.runcmd("ip link set %s up" % (imqdev))
        # Flushes the whole FORWARD chain before adding our rule.
        util.runcmd("%s -F FORWARD" % (imqebt))
        util.runcmd("%s -A FORWARD -i %s -j imq --todev %s" % (imqebt, vid, imqdev))

        return imqdev

class SignalException(Exception):
    """Raised from run()'s signal handler; carries the signal number."""

def run(cfg):
    """Checkpoint domain cfg.domid to cfg.host:cfg.port until stopped.

    Builds one buffer per replicated disk (and, if cfg.netbuffer is set,
    one Netbuffer per vif), hands the postsuspend/preresume/commit hooks
    to save.Saver and runs it. The loop ends on a CheckpointError
    (exit status 1), on Ctrl-C or on SIGTERM (exit status 0).

    This function does not return: it always finishes with sys.exit().
    Exceptions other than the three above propagate to the caller, but
    only after the buffers have been uninstalled.
    """
    # Mutable state shared with the nested functions is kept as attributes
    # on this throwaway function object (cmd, starttime).
    closure = lambda: None
    closure.cmd = None

    def sigexception(signo, frame):
        'Signal handler: convert the signal into a SignalException'
        raise SignalException(signo)

    def die():
        'Fault injection: take eth2 down, destroy the domain, bring eth2 up'
        # I am not sure what the best way to die is. xm destroy is another option,
        # or we could attempt to trigger some instant reboot.
        print("dying...")
        print(util.runcmd(['sudo', 'ifdown', 'eth2']))
        # dangling imq0 handle on vif locks up the system
        for buf in bufs:
            buf.uninstall()
        print(util.runcmd(['sudo', 'xm', 'destroy', cfg.domid]))
        print(util.runcmd(['sudo', 'ifup', 'eth2']))

    def getcommand():
        """Get a command to execute while running.
        Commands include:
          s: die prior to postsuspend hook
          s2: die after postsuspend hook
          r: die prior to preresume hook
          r2: die after preresume hook
          c: die prior to commit hook
          c2: die after commit hook
          """
        # Zero timeout: poll stdin without blocking the checkpoint loop.
        r, w, x = select.select([sys.stdin], [], [], 0)
        if sys.stdin not in r:
            return

        cmd = sys.stdin.readline().strip()
        if cmd not in ('s', 's2', 'r', 'r2', 'c', 'c2'):
            print("unknown command: %s" % cmd)
            # Do not let a typo replace a pending valid command.
            return
        closure.cmd = cmd

    signal.signal(signal.SIGTERM, sigexception)

    dom = vm.VM(cfg.domid)

    # set up I/O buffers
    bufs = []

    # disks must commit before network can be released
    for disk in dom.disks:
        try:
            bufs.append(ReplicatedDisk(disk))
        except ReplicatedDiskException:
            # Unusable disks are reported and skipped. sys.exc_info()
            # avoids the version-specific "except X, e" syntax.
            print(sys.exc_info()[1])

    if cfg.netbuffer:
        # NOTE(review): one Netbuffer is created per vif, but each is built
        # from the same domid -- confirm this is right for multi-vif guests.
        for vif in dom.vifs:
            bufs.append(Netbuffer(dom.domid))

    fd = save.MigrationSocket((cfg.host, cfg.port))

    def postsuspend():
        'Begin external checkpointing after domain has paused'
        if not cfg.timer:
            # when not using a timer thread, sleep until now + interval
            closure.starttime = time.time()

        if closure.cmd == 's':
            die()

        for buf in bufs:
            buf.postsuspend()

        if closure.cmd == 's2':
            die()

    def preresume():
        'Complete external checkpointing before domain resumes'
        if closure.cmd == 'r':
            die()

        for buf in bufs:
            buf.preresume()

        if closure.cmd == 'r2':
            die()

    def commit():
        'commit network buffer'
        if closure.cmd == 'c':
            die()

        sys.stderr.write("PROF: flushed memory at %0.6f\n" % (time.time()))

        for buf in bufs:
            buf.commit()

        if closure.cmd == 'c2':
            die()

        # Since the domain is running at this point, it's a good time to
        # check for control channel commands
        getcommand()

        if not cfg.timer:
            # Pad the round so that suspend-to-suspend takes at least
            # cfg.interval milliseconds.
            endtime = time.time()
            elapsed = (endtime - closure.starttime) * 1000

            if elapsed < cfg.interval:
                time.sleep((cfg.interval - elapsed) / 1000.0)

        # False ends checkpointing
        return True

    # With --timer the Saver is given the interval; otherwise it gets 0
    # and commit() above does the pacing.
    if cfg.timer:
        interval = cfg.interval
    else:
        interval = 0

    rc = 0

    checkpointer = save.Saver(cfg.domid, fd, postsuspend, preresume, commit,
                              interval)

    try:
        try:
            checkpointer.start()
        except save.CheckpointError:
            print(sys.exc_info()[1])
            rc = 1
        except KeyboardInterrupt:
            pass
        except SignalException:
            print('*** signalled ***')
    finally:
        # Runs on unexpected exceptions too: a buffer left installed can
        # lock up the system (see die()).
        for buf in bufs:
            buf.uninstall()

    if cfg.nobackup:
        # lame attempt to kill backup if protection is stopped deliberately.
        # It would be much better to move this into the heartbeat "protocol".
        print(util.runcmd(['sudo', '-u', os.getlogin(), 'ssh', cfg.host, 'sudo', 'xm', 'destroy', dom.name]))

    sys.exit(rc)

# Script entry: parse the command line, then hand over to run().
# A bad command line prints the error plus usage; a VM error prints the
# error only. Both exit with status 1.
cfg = Cfg()
try:
    cfg.getargs()
except CfgException:
    inst = sys.exc_info()[1]
    print(str(inst))
    cfg.usage()
    sys.exit(1)

try:
    run(cfg)
except vm.VMException:
    inst = sys.exc_info()[1]
    print(str(inst))
    sys.exit(1)
