# QEMU Manual
# https://qemu.weilnetz.de/doc/qemu-doc.html

# For information about QEMU Monitor Protocol (QMP) commands, see
# https://qemu.weilnetz.de/doc/qemu-doc.html#pcsys_005fmonitor

import os
import random
import subprocess as sp
import tempfile
import time

from functools import wraps
from os.path import join as join_path
from string import Template
from typing import Union

import bitmath
import sshtunnel

from common.helpers import get_ipv6_address
from common.request import RequestEntry, RequestType
from common.vm import VMEntry, VMStatus
from config import etcd_client, request_pool, running_vms, vm_pool, env_vars, image_storage_handler
from . import qmp
from host import logger


class VM:
    """Pair a VM's key with its QEMU machine handle and VNC socket file.

    Instances are what this module stores in ``running_vms`` and looks up by
    ``key`` (see ``get_vm``).
    """

    def __init__(self, key, handle, vnc_socket_file):
        # Identifier compared against VMEntry.key when searching running VMs.
        self.key = key  # type: str
        # Handle used to launch/command/shut down the QEMU process.
        self.handle = handle  # type: qmp.QEMUMachine
        # Temp file whose path is handed to QEMU as the VNC unix socket.
        self.vnc_socket_file = vnc_socket_file  # type: tempfile.NamedTemporaryFile

    def __repr__(self):
        return "VM({key})".format(key=self.key)


def create_dev(script, _id, dev, ip=None):
    """Run a network helper *script* as ``script _id dev [ip]``.

    :param script: path (or name) of the executable to run.
    :param _id: first argument passed to the script.
    :param dev: second argument passed to the script.
    :param ip: optional third argument; appended only when truthy.
    :returns: the script's stdout decoded as UTF-8 and stripped, or ``None``
        if the script could not be executed or exited with a non-zero status.
    """
    command = [script, _id, dev]
    if ip:
        command.append(ip)
    try:
        output = sp.check_output(command, stderr=sp.PIPE)
    except sp.CalledProcessError as e:
        # The script ran but failed: surface its captured stderr for diagnosis.
        if e.stderr:
            print(e.stderr.decode("utf-8", errors="replace").strip())
        else:
            print(e)
        return None
    except OSError as e:
        # Script missing / not executable. These errors carry no ``stderr``
        # attribute, which previously raised AttributeError here.
        print(e)
        return None
    return output.decode("utf-8").strip()


def create_vxlan_br_tap(_id, _dev, ip=None):
    """Create a VXLAN -> bridge -> tap chain using the scripts in ``network/``.

    Each step's output (the created device name) feeds the next script. The
    bridge script also receives *ip* when given. The tap gets a random numeric
    id.

    :returns: the tap device name, or ``None`` if any step produced no output.
    """
    scripts_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'network')

    def script_path(name):
        return os.path.join(scripts_dir, name)

    vxlan = create_dev(script=script_path('create-vxlan.sh'), _id=_id, dev=_dev)
    if not vxlan:
        return None

    bridge = create_dev(script=script_path('create-bridge.sh'), _id=_id, dev=vxlan, ip=ip)
    if not bridge:
        return None

    tap_id = str(random.randint(1, 100000))
    tap = create_dev(script=script_path('create-tap.sh'), _id=tap_id, dev=bridge)
    return tap if tap else None


def random_bytes(num=6):
    """Return a list of *num* random integers, each in ``range(256)``."""
    return [random.randrange(256) for _ in range(num)]


def generate_mac(uaa=False, multicast=False, oui=None, separator=':', byte_fmt='%02x'):
    """Return a random 6-byte MAC address as a string.

    :param uaa: when no *oui* is given, clear the locally-administered bit
        (bit 1 of the first byte) instead of setting it.
    :param multicast: when no *oui* is given, set the multicast bit (bit 0 of
        the first byte) instead of clearing it.
    :param oui: optional fixed prefix. It is either a sequence of ints, or a
        string of hex octets joined by *separator* (e.g. ``"52:54:00"``). The
        remaining bytes are random, and the uaa/multicast bits are left as given.
    :param separator: string placed between octets in the result.
    :param byte_fmt: %-format applied to each octet (hex by default).
    """
    if oui:
        if isinstance(oui, str):
            # OUIs are written in hex, matching the default byte_fmt; parsing
            # them as decimal mangled "52:54:00" and rejected "..:0a".
            oui = [int(chunk, 16) for chunk in oui.split(separator)]
        mac = list(oui) + random_bytes(num=6 - len(oui))
    else:
        mac = random_bytes()
        if multicast:
            mac[0] |= 1  # set bit 0
        else:
            mac[0] &= ~1  # clear bit 0
        if uaa:
            mac[0] &= ~(1 << 1)  # clear bit 1
        else:
            mac[0] |= 1 << 1  # set bit 1
    return separator.join(byte_fmt % b for b in mac)


def update_radvd_conf(etcd_client):
    """Regenerate ``/etc/radvd.conf`` from etcd networks and restart radvd.

    Every network under ``/v1/network/`` that has an ``ipv6`` prefix gets one
    rendering of ``network/radvd-template.conf``, with ``$bridge`` set to
    ``br<id>`` and ``$prefix`` set to the network's ipv6 value.

    :param etcd_client: client used to read the networks (this parameter
        shadows the module-level ``etcd_client`` import).
    :raises subprocess.CalledProcessError: if ``systemctl restart radvd`` fails.
    """
    network_script_base = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'network')

    networks = {
        net.value['ipv6']: net.value['id']
        for net in etcd_client.get_prefix('/v1/network/', value_in_json=True)
        if net.value.get('ipv6')
    }

    # Use a context manager so the template file handle is always closed.
    with open(os.path.join(network_script_base, 'radvd-template.conf'), 'r') as template_file:
        radvd_template = Template(template_file.read())

    content = [
        radvd_template.safe_substitute(bridge='br{}'.format(net_id), prefix=prefix)
        for prefix, net_id in networks.items()
        if net_id
    ]

    with open('/etc/radvd.conf', 'w') as radvd_conf:
        radvd_conf.writelines(content)

    sp.check_output(['systemctl', 'restart', 'radvd'])


def get_start_command_args(vm_entry, vnc_sock_filename: str, migration=False, migration_port=None):
    """Build the ``qemu-system`` argument list for *vm_entry*.

    The list contains the VM name, raw virtio disk, RNG device, VNC on a unix
    socket, memory/SMP settings, an optional ``-incoming`` listener, and one
    tap netdev plus virtio-net device per entry in ``vm_entry.network``.
    VXLAN networks get a VXLAN/bridge/tap chain created on the fly. radvd is
    then regenerated once if any VXLAN network was set up.

    :param vm_entry: VM whose ``specs``, ``uuid``, ``owner``, ``name`` and
        ``network`` (iterable of ``(network_name, mac)`` pairs) are read.
    :param vnc_sock_filename: path QEMU should use for the VNC unix socket.
    :param migration: when true, add ``-incoming tcp:[::]:<migration_port>``.
    :param migration_port: port for the incoming migration listener.
    :returns: list of argument strings.
    """
    threads_per_core = 1
    vm_memory = int(bitmath.parse_string_unsafe(vm_entry.specs["ram"]).to_MB())
    vm_cpus = int(vm_entry.specs["cpu"])

    # Build argv directly instead of splitting a joined string on spaces, so
    # values that contain spaces (e.g. a VM name) remain single arguments.
    command = [
        "-name", "{}_{}".format(vm_entry.owner, vm_entry.name),
        "-drive", "file={},format=raw,if=virtio,cache=none".format(
            image_storage_handler.qemu_path_string(vm_entry.uuid)
        ),
        "-device", "virtio-rng-pci",
        "-vnc", "unix:{}".format(vnc_sock_filename),
        "-m", str(vm_memory),
        "-smp", "cores={},threads={}".format(vm_cpus, threads_per_core),
    ]

    if migration:
        command += ["-incoming", "tcp:[::]:{}".format(migration_port)]

    has_vxlan = False
    for network_name, mac in vm_entry.network:
        _key = os.path.join(env_vars.get('NETWORK_PREFIX'), vm_entry.owner, network_name)
        network = etcd_client.get(_key, value_in_json=True)
        network_type = network.value["type"]
        network_id = str(network.value["id"])
        network_ipv6 = network.value.get("ipv6")

        # Reset per network so a previous network's tap device is never reused.
        tap = None
        if network_type == "vxlan":
            tap = create_vxlan_br_tap(network_id, env_vars.get("VXLAN_PHY_DEV"), network_ipv6)
            has_vxlan = True

        if tap is None:
            # NOTE(review): QEMU is still given "ifname=None" here, as before;
            # confirm whether such networks should be skipped instead.
            logger.warning("No tap device available for network %s of VM %s",
                           network_name, vm_entry.uuid)

        command += [
            "-netdev", "tap,id=vmnet{net_id},ifname={tap},script=no,downscript=no".format(
                net_id=network_id, tap=tap
            ),
            "-device", "virtio-net-pci,netdev=vmnet{net_id},mac={mac}".format(
                net_id=network_id, mac=mac
            ),
        ]

    # Restart radvd once after all bridges exist, rather than once per network.
    if has_vxlan:
        update_radvd_conf(etcd_client)

    return command


def create_vm_object(vm_entry, migration=False, migration_port=None):
    """Return a not-yet-launched ``VM`` wrapping a QEMUMachine for *vm_entry*.

    A NamedTemporaryFile provides the path used as the VNC unix socket. It is
    kept on the returned ``VM`` so the file object stays referenced.
    """
    # NOTE: If migration suddenly stops working, differing VNC unix socket
    #       filenames on the source and destination hosts are a possible cause.
    # REQUIREMENT: VNC must use a unix socket rather than a TCP port.
    socket_file = tempfile.NamedTemporaryFile()

    args = get_start_command_args(
        vm_entry,
        socket_file.name,
        migration,
        migration_port,
    )
    machine = qmp.QEMUMachine("/usr/bin/qemu-system-x86_64", args=args)
    return VM(vm_entry.key, machine, socket_file)


def get_vm(vm_list: list, vm_key) -> Union[VM, None]:
    """Return the first element of *vm_list* whose ``key`` equals *vm_key*, else None."""
    for candidate in vm_list:
        if candidate.key == vm_key:
            return candidate
    return None


def need_running_vm(func):
    """Decorate a ``func(vm_entry)`` so it only runs for a responsive running VM.

    The wrapper looks up ``vm_entry.key`` in ``running_vms`` and probes it with
    QMP ``query-status``. If the VM is absent or the probe raises, it logs the
    reason and returns None without calling *func*. Otherwise it returns
    ``func(vm_entry)``.
    """
    @wraps(func)
    def wrapper(e):
        vm = get_vm(running_vms, e.key)
        if not vm:
            logger.info("%s failed because VM %s is not running", func.__name__, e.key)
            return None

        try:
            status = vm.handle.command("query-status")
            logger.debug("VM Status Check - %s", status)
        except Exception as exception:
            logger.info("%s failed - VM %s %s", func.__name__, e, exception)
            return None

        return func(e)

    return wrapper


def create(vm_entry: VMEntry):
    """Ensure a disk image exists for *vm_entry*, cloning and resizing if needed.

    If no image exists yet, it is made from ``vm_entry.image_uuid`` and
    resized to ``specs["os-ssd"]``. On failure of either step,
    ``vm_entry.status`` is set to ``VMStatus.error``; the entry is not
    written back here.

    :returns: True if the image exists or was created and resized, False
        otherwise (callers that ignore the result are unaffected).
    """
    if image_storage_handler.is_vm_image_exists(vm_entry.uuid):
        # File already exists. No problem, continue.
        logger.debug("Image for vm %s exists", vm_entry.uuid)
        return True

    vm_hdd = int(bitmath.parse_string_unsafe(vm_entry.specs["os-ssd"]).to_MB())

    if not image_storage_handler.make_vm_image(src=vm_entry.image_uuid, dest=vm_entry.uuid):
        # Previously this failure was silently ignored.
        logger.error("Couldn't create image for VM %s from image %s",
                     vm_entry.uuid, vm_entry.image_uuid)
        vm_entry.status = VMStatus.error
        return False

    if not image_storage_handler.resize_vm_image(path=vm_entry.uuid, size=vm_hdd):
        logger.error("Couldn't resize image of VM %s", vm_entry.uuid)
        vm_entry.status = VMStatus.error
        return False

    logger.info("New VM Created")
    return True


def start(vm_entry: VMEntry, destination_host_key=None, migration_port=None):
    """Start *vm_entry* unless it is already in ``running_vms``.

    With *destination_host_key*, launch as the receiving side of a migration
    on *migration_port*. Otherwise ensure the disk image exists and do a
    normal launch.
    """
    if get_vm(running_vms, vm_entry.key):
        # VM already running. No need to proceed further.
        logger.info("VM %s already running" % vm_entry.uuid)
        return

    logger.info("Trying to start %s" % vm_entry.uuid)

    if not destination_host_key:
        create(vm_entry)
        launch_vm(vm_entry)
        return

    launch_vm(
        vm_entry,
        migration=True,
        migration_port=migration_port,
        destination_host_key=destination_host_key,
    )


@need_running_vm
def stop(vm_entry):
    """Shut down the running VM for *vm_entry*.

    If the QEMU handle reports it is no longer running, the entry is marked
    stopped, persisted to ``vm_pool``, and removed from ``running_vms``.
    """
    target = get_vm(running_vms, vm_entry.key)
    target.handle.shutdown()
    if target.handle.is_running():
        return

    vm_entry.add_log("Shutdown successfully")
    vm_entry.declare_stopped()
    vm_pool.put(vm_entry)
    running_vms.remove(target)


def delete(vm_entry):
    """Stop *vm_entry*, delete its disk image, and remove its etcd key on success."""
    logger.info("Deleting VM | %s", vm_entry)
    stop(vm_entry)

    if not image_storage_handler.delete_vm_image(vm_entry.uuid):
        return
    etcd_client.client.delete(vm_entry.key)


def _wait_for_migration(vm, poll_interval=2):
    """Poll QMP ``query-migrate`` on *vm* until a terminal status; return it.

    "cancelled" is terminal too. Waiting only for "failed"/"completed" spun
    forever after a cancelled migration. A reply without ``status`` is treated
    as "failed" so the loop cannot hang.
    """
    while True:
        status = vm.handle.command("query-migrate").get("status", "failed")
        if status in ("completed", "failed", "cancelled"):
            return status
        time.sleep(poll_interval)


def transfer(request_event):
    """Migrate a running VM from this (source) host to the destination host.

    An SSH tunnel is opened to ``parameters["host"]``, forwarding to
    ``127.0.0.1:parameters["port"]`` on that host. QEMU is told to migrate
    through the tunnel's local port. Once migration finishes, the VM entry in
    ``vm_pool`` is updated:
    - on completion, its ``hostname`` is set to the destination host key and
      the local VM is removed and shut down;
    - otherwise, a failure is logged.
    ``in_migration`` is cleared in both cases.
    """
    _host, _port = request_event.parameters["host"], request_event.parameters["port"]
    _uuid = request_event.uuid
    _destination = request_event.destination_host_key
    vm = get_vm(running_vms, join_path(env_vars.get('VM_PREFIX'), _uuid))

    if not vm:
        logger.info("Can't transfer VM %s because it is not running on this host", _uuid)
        return

    tunnel = sshtunnel.SSHTunnelForwarder(
        _host,
        ssh_username=env_vars.get("ssh_username"),
        ssh_pkey=env_vars.get("ssh_pkey"),
        remote_bind_address=("127.0.0.1", _port),
        ssh_proxy_enabled=True,
        ssh_proxy=(_host, 22)
    )
    try:
        tunnel.start()
    except sshtunnel.BaseSSHTunnelForwarderError:
        logger.exception("Couldn't establish connection to (%s, 22)", _host)
    else:
        vm.handle.command(
            "migrate", uri="tcp:0.0.0.0:{}".format(tunnel.local_bind_port)
        )

        status = _wait_for_migration(vm)

        with vm_pool.get_put(request_event.uuid) as source_vm:
            if status == "completed":
                # Migration succeeded: shut the VM down on this host and point
                # its hostname at the destination host key.
                source_vm.add_log("Successfully migrated")
                source_vm.hostname = _destination
                running_vms.remove(vm)
                vm.handle.shutdown()
            else:
                # "failed", "cancelled", or an unknown terminal state.
                source_vm.add_log("Migration Failed")
            source_vm.in_migration = False  # VM transfer finished
    finally:
        tunnel.close()


def launch_vm(vm_entry, migration=False, migration_port=None, destination_host_key=None):
    """Launch the QEMU process for *vm_entry* and record the outcome.

    On success the VM is appended to ``running_vms`` and its VNC socket path
    is stored on the entry. For a normal launch the entry is marked running.
    For an incoming migration (``migration=True``), ``in_migration`` is set and
    a ``TransferVM`` request is queued for ``vm_entry.hostname`` (presumably
    the source host, since ``transfer`` runs there). The entry is then
    written to ``vm_pool``.

    On failure the QEMU handle is shut down once. For a normal launch, the
    entry is then declared killed and written to ``vm_pool``; a failed
    migration launch is not reported here.

    :param migration: whether this is the receiving side of a migration.
    :param migration_port: TCP port QEMU listens on for the incoming state.
    :param destination_host_key: host key placed in the transfer request.
    """
    logger.info("Starting %s", vm_entry.key)

    vm = create_vm_object(vm_entry, migration=migration, migration_port=migration_port)
    try:
        vm.handle.launch()
    except Exception:
        logger.exception("Error Occured while starting VM")
        # Shut down exactly once; the non-migration path used to call
        # shutdown() a second time. Guard it so declare_killed still runs.
        try:
            vm.handle.shutdown()
        except Exception:
            logger.exception("Error while cleaning up VM %s after failed launch", vm_entry.key)

        if not migration:
            # Error during typical launch of a vm. For a migration launch we
            # don't care whether MachineError or any other error occurred.
            vm_entry.declare_killed()
            vm_pool.put(vm_entry)
        return

    vm_entry.vnc_socket = vm.vnc_socket_file.name
    running_vms.append(vm)

    if migration:
        vm_entry.in_migration = True
        r = RequestEntry.from_scratch(
            type=RequestType.TransferVM,
            hostname=vm_entry.hostname,
            parameters={"host": get_ipv6_address(), "port": migration_port},
            uuid=vm_entry.uuid,
            destination_host_key=destination_host_key,
            request_prefix=env_vars.get("REQUEST_PREFIX")
        )
        request_pool.put(r)
    else:
        # Typical launching of a vm
        vm_entry.status = VMStatus.running
        vm_entry.add_log("Started successfully")

    vm_pool.put(vm_entry)