"""
This module contain classes thats validates and intercept/modify
data coming from ucloud-cli (user)

It was primarily developed as an alternative to argument parser
of Flask_Restful which is going to be deprecated. I also tried
marshmallow for that purpose but it was an overkill (because it
do validation + serialization + deserialization) and little
inflexible for our purpose.
"""

# TODO: Fix the error message shown when the user's VM (referred to by name)
#       does not exist.
#
#       Currently, it says uuid is a required field.

import json
import os

import bitmath

from ucloud.common.host import HostStatus
from ucloud.common.vm import VMStatus
from ucloud.config import etcd_client, config, vm_pool, host_pool
from . import helper
from .common_fields import Field, VmUUIDField
from .helper import check_otp, resolve_vm_name


class BaseSchema:
    """Common machinery shared by all request validators in this module.

    A schema owns a list of field objects plus an optional schema-level
    ``validation`` hook. ``is_valid`` runs every field's own check, then the
    ``validation`` hooks of the direct parent classes, and finally this
    class's own ``validation`` (only when no error has been recorded yet).
    Error messages are collected in a single flat list.
    """

    def __init__(self, data, fields=None):
        _ = data  # suppress linter warning
        self.__errors = []
        if fields is None:
            self.fields = []
        else:
            self.fields = fields

    def validation(self):
        # custom validation is optional
        return True

    def is_valid(self):
        """Run all validations; return ``True`` when no error was recorded."""
        for field in self.fields:
            field.is_valid()
            self.add_field_errors(field)

        # Run each direct parent's schema-level hook. Bases without a
        # ``validation`` attribute (e.g. ``object``) are skipped explicitly
        # instead of catching AttributeError, which would also have silently
        # swallowed AttributeErrors raised *inside* a hook (e.g. letting a
        # credential check be skipped without recording any error).
        already_run = []
        for parent in type(self).__bases__:
            hook = getattr(parent, "validation", None)
            if hook is None or hook in already_run:
                continue
            hook(self)
            already_run.append(hook)

        # A subclass that does not override ``validation`` inherits its
        # parent's hook, which has just run above; running it again would
        # only repeat the same work (for OTP schemas: a second check_otp call).
        own_hook = getattr(self.validation, "__func__", self.validation)
        if not self.__errors and own_hook not in already_run:
            self.validation()

        return not self.__errors

    def get_errors(self):
        """Return the collected errors as ``{"message": [<str>, ...]}``."""
        return {"message": self.__errors}

    def add_field_errors(self, field: "Field"):
        """Append all errors reported by *field* to this schema's errors."""
        self.__errors += field.get_errors()

    def add_error(self, error):
        """Record a single schema-level error message."""
        self.__errors.append(error)


class OTPSchema(BaseSchema):
    """Base schema for requests authenticated with a name/realm/token triple.

    The three credential fields come first in the field list; any extra
    *fields* supplied by a subclass are appended after them.
    """

    def __init__(self, data: dict, fields=None):
        self.name, self.realm, self.token = (
            Field(key, str, data.get(key, KeyError))
            for key in ("name", "realm", "token")
        )

        all_fields = [self.name, self.realm, self.token]
        all_fields.extend(fields or [])
        super().__init__(data=data, fields=all_fields)

    def validation(self):
        # check_otp is expected to answer with an HTTP-like status code;
        # anything other than 200 means the credentials were rejected.
        status = check_otp(self.name.value, self.realm.value, self.token.value)
        if status != 200:
            self.add_error("Wrong Credentials")


########################## Image Operations ###############################################


class CreateImageSchema(BaseSchema):
    """Validate a request to create an image from an uploaded file.

    Reads ``uuid`` (the file's uuid), ``name`` and ``image_store`` from
    *data*. The file must exist under the etcd FILE_PREFIX and the image
    store must be one of the stores listed under IMAGE_STORE_PREFIX.
    """

    def __init__(self, data):
        def _str_field(key):
            return Field(key, str, data.get(key, KeyError))

        self.uuid = _str_field("uuid")
        self.name = _str_field("name")
        self.image_store = _str_field("image_store")

        # Field-level hooks
        self.uuid.validation = self.file_uuid_validation
        self.image_store.validation = self.image_store_name_validation

        super().__init__(data, [self.uuid, self.name, self.image_store])

    def file_uuid_validation(self):
        """Require an etcd entry for the referenced file uuid."""
        file_key = os.path.join(config['etcd']['FILE_PREFIX'], self.uuid.value)
        if etcd_client.get(file_key) is None:
            self.add_error(
                "Image File with uuid '{}' Not Found".format(self.uuid.value)
            )

    def image_store_name_validation(self):
        """Require an image store whose JSON value has a matching ``name``."""
        wanted = self.image_store.value
        stores = list(etcd_client.get_prefix(config['etcd']['IMAGE_STORE_PREFIX']))
        match = next(
            (entry for entry in stores if json.loads(entry.value)["name"] == wanted),
            None,
        )
        if not match:
            self.add_error("Store '{}' does not exists".format(wanted))


# Host Operations

class CreateHostSchema(OTPSchema):
    """Validate an admin request to register a new host.

    Expects the OTP credentials plus ``hostname`` (str) and ``specs`` (dict
    with ``cpu``, ``ram``, ``os-ssd`` and ``hdd`` keys) in *data*. Only the
    ``ungleich-admin`` realm passes ``validation``.
    """

    def __init__(self, data):
        # Normalised specs; filled by specs_validation when the specs are valid.
        self.parsed_specs = {}
        # Fields
        self.specs = Field("specs", dict, data.get("specs", KeyError))
        self.hostname = Field("hostname", str, data.get("hostname", KeyError))

        # Validation
        self.specs.validation = self.specs_validation

        fields = [self.hostname, self.specs]

        super().__init__(data=data, fields=fields)

    def specs_validation(self):
        """Check the host's resource specs and normalise them on success.

        RAM, OS-SSD and every HDD size must parse (via bitmath) to base-10
        (SI) units; at least 1 CPU, 1 GB RAM and 10 GB OS-SSD are required.
        When no error has been recorded, ``self.parsed_specs`` is set to the
        normalised specs and ``self.specs`` is replaced by that same dict.
        """
        ALLOWED_BASE = 10

        _cpu = self.specs.value.get('cpu', KeyError)
        _ram = self.specs.value.get('ram', KeyError)
        _os_ssd = self.specs.value.get('os-ssd', KeyError)
        _hdd = self.specs.value.get('hdd', KeyError)

        if KeyError in [_cpu, _ram, _os_ssd, _hdd]:
            self.add_error("You must specify CPU, RAM, OS-SSD and HDD in your specs")
            return None
        try:
            parsed_ram = bitmath.parse_string_unsafe(_ram)
            parsed_os_ssd = bitmath.parse_string_unsafe(_os_ssd)

            if parsed_ram.base != ALLOWED_BASE:
                self.add_error("Your specified RAM is not in correct units")
            if parsed_os_ssd.base != ALLOWED_BASE:
                self.add_error("Your specified OS-SSD is not in correct units")

            if _cpu < 1:
                self.add_error("CPU must be atleast 1")

            if parsed_ram < bitmath.GB(1):
                self.add_error("RAM must be atleast 1 GB")

            if parsed_os_ssd < bitmath.GB(10):
                self.add_error("OS-SSD must be atleast 10 GB")

            parsed_hdd = []
            for hdd in _hdd:
                _parsed_hdd = bitmath.parse_string_unsafe(hdd)
                if _parsed_hdd.base != ALLOWED_BASE:
                    self.add_error("Your specified HDD is not in correct units")
                    break
                else:
                    parsed_hdd.append(str(_parsed_hdd))

        except (ValueError, TypeError):
            # ValueError: unparsable size strings. TypeError: values of the
            # wrong type, e.g. a string CPU count compared with ``< 1`` or a
            # non-iterable ``hdd``. Both used to crash the request if TypeError.
            # TODO: Find some good error message
            self.add_error("Specs are not correct.")
        else:
            # get_errors() returns {"message": [...]}, which is always truthy;
            # inspect the message list itself and only publish valid specs.
            if not self.get_errors()["message"]:
                self.parsed_specs = {
                    'cpu': _cpu,
                    'ram': str(parsed_ram),
                    'os-ssd': str(parsed_os_ssd),
                    'hdd': parsed_hdd
                }
                # Callers read the normalised values from ``self.specs``.
                self.specs = self.parsed_specs

    def validation(self):
        """Only members of the ``ungleich-admin`` realm may create hosts."""
        if self.realm.value != "ungleich-admin":
            self.add_error("Invalid Credentials/Insufficient Permission")


# VM Operations


class CreateVMSchema(OTPSchema):
    """Validate a request to create a VM.

    Expects the OTP credentials plus ``vm_name`` (str), ``image`` (str),
    ``specs`` (dict with ``cpu``, ``ram``, ``os-ssd`` and ``hdd`` keys) and
    ``network`` (list of the requesting user's network names) in *data*.
    """

    def __init__(self, data):
        # Normalised specs; filled by specs_validation when the specs are valid.
        self.parsed_specs = {}

        # Fields
        self.specs = Field("specs", dict, data.get("specs", KeyError))
        self.vm_name = Field("vm_name", str, data.get("vm_name", KeyError))
        self.image = Field("image", str, data.get("image", KeyError))
        self.network = Field("network", list, data.get("network", KeyError))

        # Validation
        self.image.validation = self.image_validation
        self.vm_name.validation = self.vm_name_validation
        self.specs.validation = self.specs_validation
        self.network.validation = self.network_validation

        fields = [self.vm_name, self.image, self.specs, self.network]

        super().__init__(data=data, fields=fields)

    def image_validation(self):
        """Resolve the image name; on success store it as ``self.image_uuid``."""
        try:
            image_uuid = helper.resolve_image_name(self.image.value, etcd_client)
        except Exception as e:
            # The helper's failure modes are not visible here; its message is
            # surfaced to the user as-is.
            self.add_error(str(e))
        else:
            self.image_uuid = image_uuid

    def vm_name_validation(self):
        """Reject a VM name the requesting user already uses."""
        if resolve_vm_name(name=self.vm_name.value, owner=self.name.value):
            self.add_error(
                'VM with same name "{}" already exists'.format(self.vm_name.value)
            )

    def network_validation(self):
        """Require every requested network to exist under the user's prefix."""
        _network = self.network.value

        if _network:
            for net in _network:
                network_key = os.path.join(
                    config['etcd']['NETWORK_PREFIX'], self.name.value, net
                )
                network = etcd_client.get(network_key, value_in_json=True)
                if not network:
                    self.add_error("Network with name {} does not exists".format(net))

    def specs_validation(self):
        """Check the VM's resource specs and normalise them on success.

        RAM, OS-SSD and every HDD size must parse (via bitmath) to base-10
        (SI) units; at least 1 CPU, 1 GB RAM and 1 GB OS-SSD are required.
        When no error has been recorded, ``self.parsed_specs`` is set to the
        normalised specs and ``self.specs`` is replaced by that same dict.
        """
        ALLOWED_BASE = 10

        _cpu = self.specs.value.get('cpu', KeyError)
        _ram = self.specs.value.get('ram', KeyError)
        _os_ssd = self.specs.value.get('os-ssd', KeyError)
        _hdd = self.specs.value.get('hdd', KeyError)

        if KeyError in [_cpu, _ram, _os_ssd, _hdd]:
            self.add_error("You must specify CPU, RAM, OS-SSD and HDD in your specs")
            return None
        try:
            parsed_ram = bitmath.parse_string_unsafe(_ram)
            parsed_os_ssd = bitmath.parse_string_unsafe(_os_ssd)

            if parsed_ram.base != ALLOWED_BASE:
                self.add_error("Your specified RAM is not in correct units")
            if parsed_os_ssd.base != ALLOWED_BASE:
                self.add_error("Your specified OS-SSD is not in correct units")

            if _cpu < 1:
                self.add_error("CPU must be atleast 1")

            if parsed_ram < bitmath.GB(1):
                self.add_error("RAM must be atleast 1 GB")

            if parsed_os_ssd < bitmath.GB(1):
                self.add_error("OS-SSD must be atleast 1 GB")

            parsed_hdd = []
            for hdd in _hdd:
                _parsed_hdd = bitmath.parse_string_unsafe(hdd)
                if _parsed_hdd.base != ALLOWED_BASE:
                    self.add_error("Your specified HDD is not in correct units")
                    break
                else:
                    parsed_hdd.append(str(_parsed_hdd))

        except (ValueError, TypeError):
            # ValueError: unparsable size strings. TypeError: values of the
            # wrong type, e.g. a string CPU count compared with ``< 1`` or a
            # non-iterable ``hdd``; previously that crashed the request.
            # TODO: Find some good error message
            self.add_error("Specs are not correct.")
        else:
            # get_errors() returns {"message": [...]}, which is always truthy;
            # inspect the message list itself and only publish valid specs.
            if not self.get_errors()["message"]:
                self.parsed_specs = {
                    'cpu': _cpu,
                    'ram': str(parsed_ram),
                    'os-ssd': str(parsed_os_ssd),
                    'hdd': parsed_hdd
                }
                # Callers read the normalised values from ``self.specs``.
                self.specs = self.parsed_specs


class VMStatusSchema(OTPSchema):
    """Validate a status query for a VM referred to by name.

    The VM name is resolved against ``in_support_of`` (when given) or else
    the requesting user's ``name``. The resolved uuid is written back into
    *data* under ``"uuid"``; ``KeyError`` is stored when resolution fails.
    VmUUIDField then reads it from there.
    """

    def __init__(self, data):
        owner = data.get("in_support_of", None) or data.get("name", None)
        resolved_uuid = resolve_vm_name(name=data.get("vm_name", None), owner=owner)
        # NOTE: this mutates the caller's dict.
        data["uuid"] = resolved_uuid or KeyError
        self.uuid = VmUUIDField(data)

        super().__init__(data, [self.uuid])

    def validation(self):
        """Allow only the VM's owner or an ``ungleich-admin`` user."""
        vm = vm_pool.get(self.uuid.value)
        is_owner = vm.value["owner"] == self.name.value
        is_admin = self.realm.value == "ungleich-admin"
        if not is_owner and not is_admin:
            self.add_error("Invalid User")


class VmActionSchema(OTPSchema):
    """Validate a start/stop/delete request for a VM referred to by name.

    The VM name is resolved against ``in_support_of`` (when given) or else
    the requesting user's ``name``. The resolved uuid is written into *data*
    under ``"uuid"``; ``KeyError`` is stored when resolution fails.
    VmUUIDField then reads it from there.
    """

    def __init__(self, data):
        # NOTE: this mutates the caller's dict.
        data["uuid"] = (
            resolve_vm_name(
                name=data.get("vm_name", None),
                owner=(data.get("in_support_of", None) or data.get("name", None)),
            )
            or KeyError
        )
        self.uuid = VmUUIDField(data)
        self.action = Field("action", str, data.get("action", KeyError))

        self.action.validation = self.action_validation

        _fields = [self.uuid, self.action]

        super().__init__(data=data, fields=_fields)

    def action_validation(self):
        """Accept only the supported VM actions."""
        allowed_actions = ["start", "stop", "delete"]
        if self.action.value not in allowed_actions:
            self.add_error(
                "Invalid Action. Allowed Actions are {}".format(allowed_actions)
            )

    def validation(self):
        """Check ownership, then that the action fits the VM's current state."""
        vm = vm_pool.get(self.uuid.value)
        if not (
            vm.value["owner"] == self.name.value or self.realm.value == "ungleich-admin"
        ):
            self.add_error("Invalid User")
            # Stop here so a non-owner learns nothing about the VM's state.
            return

        if (
            self.action.value == "start"
            and vm.status == VMStatus.running
            and vm.hostname != ""
        ):
            self.add_error("VM Already Running")

        if self.action.value == "stop":
            if vm.status == VMStatus.stopped:
                self.add_error("VM Already Stopped")
            elif vm.status != VMStatus.running:
                self.add_error("Cannot stop non-running VM")


class VmMigrationSchema(OTPSchema):
    """Validate a request to migrate a running VM to another live host.

    The VM name is resolved like in the other VM schemas: against
    ``in_support_of`` (when given) or else the requesting user's ``name``.
    The result is written into *data* under ``"uuid"``. ``destination`` is
    the target host's hostname.
    """

    def __init__(self, data):
        # NOTE: this mutates the caller's dict.
        data["uuid"] = (
            resolve_vm_name(
                name=data.get("vm_name", None),
                owner=(data.get("in_support_of", None) or data.get("name", None)),
            )
            or KeyError
        )

        self.uuid = VmUUIDField(data)
        self.destination = Field("destination", str, data.get("destination", KeyError))

        self.destination.validation = self.destination_validation

        # The uuid field must be validated too. Otherwise an unresolvable VM
        # name reaches ``validation`` and ``vm_pool.get`` is called with an
        # unvalidated uuid.
        fields = [self.uuid, self.destination]
        super().__init__(data=data, fields=fields)

    def destination_validation(self):
        """Require a live host with the given hostname.

        On success, the field's value is replaced by that host's key.
        """
        hostname = self.destination.value
        host = next(filter(lambda h: h.hostname == hostname, host_pool.hosts), None)
        if not host:
            self.add_error("No Such Host ({}) exists".format(self.destination.value))
        elif host.status != HostStatus.alive:
            self.add_error("Destination Host is dead")
        else:
            self.destination.value = host.key

    def validation(self):
        """Check ownership, VM state and that the VM is not already there."""
        vm = vm_pool.get(self.uuid.value)
        if not (
            vm.value["owner"] == self.name.value or self.realm.value == "ungleich-admin"
        ):
            self.add_error("Invalid User")
            # Stop here so a non-owner learns nothing about the VM's state.
            return

        if vm.status != VMStatus.running:
            self.add_error("Can't migrate non-running VM")

        # NOTE(review): destination.value now holds host.key. If that key is
        # already an absolute etcd path, os.path.join returns it unchanged, so
        # this compares vm.hostname with host.key. Confirm the key format.
        if vm.hostname == os.path.join(config['etcd']['HOST_PREFIX'], self.destination.value):
            self.add_error("Destination host couldn't be same as Source Host")


class AddSSHSchema(OTPSchema):
    """Validate a request to store an SSH public key under a name.

    Expects the OTP credentials plus ``key_name`` and ``key`` (both str).
    """

    def __init__(self, data):
        self.key_name = Field("key_name", str, data.get("key_name", KeyError))
        # The key material lives under "key", not "key_name".
        self.key = Field("key", str, data.get("key", KeyError))

        fields = [self.key_name, self.key]
        super().__init__(data=data, fields=fields)


class RemoveSSHSchema(OTPSchema):
    """Validate a request to delete one of the user's SSH keys by name.

    Expects the OTP credentials plus a required ``key_name`` (str).
    """

    def __init__(self, data):
        requested_name = data.get("key_name", KeyError)
        self.key_name = Field("key_name", str, requested_name)
        super().__init__(data=data, fields=[self.key_name])


class GetSSHSchema(OTPSchema):
    """Validate a request to read the user's SSH key(s).

    When ``key_name`` is absent it defaults to ``None``. Elsewhere in this
    module required fields default to ``KeyError`` instead. Presumably Field
    treats ``None`` as "not supplied", making the name optional; confirm
    against ``common_fields.Field``.
    """

    def __init__(self, data):
        requested_name = data.get("key_name", None)
        self.key_name = Field("key_name", str, requested_name)
        super().__init__(data=data, fields=[self.key_name])


class CreateNetwork(OTPSchema):
    """Validate a request to create a network for the requesting user.

    Expects the OTP credentials plus ``network_name`` (str), ``type`` (str,
    one of the supported types) and an optional ``user`` flag.
    """

    def __init__(self, data):
        self.network_name = Field(
            "network_name", str, data.get("network_name", KeyError)
        )
        self.type = Field("type", str, data.get("type", KeyError))
        # NOTE(review): bool() on the raw value means a string such as "false"
        # becomes True; this assumes clients send a JSON boolean. Confirm.
        self.user = Field("user", bool, bool(data.get("user", False)))

        # Field-level hooks
        self.network_name.validation = self.network_name_validation
        self.type.validation = self.network_type_validation

        super().__init__(data, fields=[self.network_name, self.type, self.user])

    def network_name_validation(self):
        """Reject a network name the user already has."""
        existing_key = os.path.join(
            config['etcd']['NETWORK_PREFIX'],
            self.name.value,
            self.network_name.value,
        )
        if etcd_client.get(existing_key, value_in_json=True):
            self.add_error(
                "Network with name {} already exists".format(self.network_name.value)
            )

    def network_type_validation(self):
        """Accept only the network types listed below."""
        supported_network_types = ["vxlan"]
        if self.type.value in supported_network_types:
            return
        self.add_error(
            "Unsupported Network Type. Supported network types are {}".format(
                supported_network_types
            )
        )