from __future__ import unicode_literals

import nnpy
import struct

from p4utils.utils.topology import Topology
from p4utils.utils.sswitch_API import SimpleSwitchAPI

from scapy.all import sniff, get_if_list, Ether, get_if_hwaddr, sendp
from scapy.all import IP, Raw, IPv6, TCP, TCP_client
from scapy.all import Ether, sniff, Packet, BitField
from scapy.all import ICMPv6ND_NS

import sys
import re

import logging
import argparse
import subprocess

import ipaddress

# Install a default root handler so messages from `log` are actually emitted.
logging.basicConfig()
log = logging.getLogger(name="main")



class CpuHeader(Packet):
    """Scapy layer for the header attached to packets sent to the CPU port.

    Layout: an 8-bit task id followed by a 16-bit ingress port. The task id
    presumably uses the same values as ``L2Controller.task`` — TODO confirm
    against the P4 program.
    """
    name = 'CpuPacket'
    fields_desc = [
        # Comma was missing here, which made the whole module a SyntaxError.
        BitField('task', 0, 8),
        BitField('ingress_port', 0, 16),
    ]


class L2Controller(object):
    """Control-plane driver for a single P4 switch.

    Builds a static IPv4/IPv6 address plan, pushes matching table entries to
    the switch through ``SimpleSwitchAPI``, configures the attached hosts via
    ``mx`` and handles packets sniffed on the CPU interface.
    """

    def __init__(self, sw_name):
        """Build the address plan and connect to switch *sw_name*.

        :param sw_name: switch name as stored in ``topology.db`` (e.g. "s1").
        """
        # Command line mapping
        self.modes = ['base', 'router']

        # Task ids; used as match keys of the ``v6_addresses`` table in
        # fill_tables (presumably mirrored in the P4 program - TODO confirm).
        self.task = {
            'ICMP6_NS': 1,
            'ICMP6_GENERAL': 2,
            'DEBUG': 3
        }

        self._init_addressing()
        self.init_boilerplate(sw_name)

    def _init_addressing(self):
        """Populate ``self.info`` and the per-mode route/address tables.

        Pure bookkeeping, no switch I/O. Every route/address dict is keyed by
        mode name ('base', 'router') plus ``None``, which is what
        ``self.mode`` holds when ``--mode`` is not given.
        """
        self.info = {}
        # Solicited-node multicast prefix (ff02::1:ffXX:XXXX).
        self.info['ndp_multicast'] = ipaddress.ip_network("ff02::1:ff00:0/104")

        self.info['v6_mask'] = 64
        self.info['v6_base'] = ipaddress.ip_network("2001:db8::/32")
        self.info['v6_gen'] = self.info['v6_base'].subnets(new_prefix=self.info['v6_mask'])

        self.info['v4_mask'] = 24
        self.info['v4_base'] = ipaddress.ip_network("10.0.0.0/8")
        self.info['v4_gen'] = self.info['v4_base'].subnets(new_prefix=self.info['v4_mask'])

        # Host offset of the switch's own address inside each network.
        self.info['switch_suffix'] = 42

        # One /64 per port 1-2. Builtin next() instead of gen.next(), which
        # only exists on Python 2.
        self.v6_routes = {None: [], 'base': []}
        for port in range(1, 3):
            self.v6_routes['base'].append({
                "net": next(self.info['v6_gen']),
                "port": port,
            })
        # 'router' aliases the same list object as 'base'.
        self.v6_routes['router'] = self.v6_routes['base']

        # One /24 per port 3-4.
        self.v4_routes = {None: [], 'base': []}
        for port in range(3, 5):
            self.v4_routes['base'].append({
                "net": next(self.info['v4_gen']),
                "port": port,
            })
        self.v4_routes['router'] = self.v4_routes['base']

        # In 'router' mode the switch owns address <suffix> in every network.
        suffix = self.info['switch_suffix']
        self.v6_addresses = {None: [], 'base': []}
        self.v6_addresses['router'] = [route['net'][suffix] for route in self.v6_routes['router']]

        self.v4_addresses = {None: [], 'base': []}
        self.v4_addresses['router'] = [route['net'][suffix] for route in self.v4_routes['router']]

    def gen_ndp_multicast_addr(self, addr):
        """Return the solicited-node multicast address for *addr*.

        The low 24 bits of *addr* are appended to ``ff02::1:ff00:0/104``.

        :param addr: an ``ipaddress`` IPv6 address (anything ``int()`` accepts).
        """
        last_24 = int(addr) & 0xffffff
        return self.info['ndp_multicast'][last_24]

    def init_ndp(self):
        """Initialise neighbor discovery protocol (NDP) support on the switch.

        For each port 1-4, creates a multicast node/group (rid == group id ==
        port) that contains all *other* ports. It then adds one ``ndp`` entry
        per port, keyed on (solicited-node prefix, port), whose action
        parameter is that port's group id.
        """
        # https://en.wikipedia.org/wiki/Solicited-node_multicast_address
        ndp_prefix = str(self.info['ndp_multicast'])

        all_ports = range(1, 5)
        # create multicast nodes
        for rid in all_ports:
            ports = [x for x in all_ports if x != rid]
            n_handle = self.controller.mc_node_create(rid, ports)
            log.debug("Creating MC node rid={} ports={} handle={}".format(rid, ports, n_handle))

            g_handle = self.controller.mc_mgrp_create(rid)
            log.debug("Creating MC group mgrp={} handle={} && associating afterwards".format(rid, g_handle))

            self.controller.mc_node_associate(g_handle, n_handle)

        self.controller.table_clear("ndp")
        for port in all_ports:
            self.controller.table_add("ndp", "multicast_pkg", [ndp_prefix, str(port)], [str(port)])

        # Special rule for switch entries. 135 is the ICMPv6 Neighbor
        # Solicitation type.
        # NOTE(review): these hard-coded addresses do not come from the
        # generated address plan (2001:db8:61::/64 is not one of v6_routes) -
        # confirm they are still intended.
        self.controller.table_add("ndp_answer", "icmp6_neighbor_solicitation", ["ff02::1:ff00:42", "135"], ["2001:db8:61::42"])

    def init_boilerplate(self, sw_name):
        """Load the topology, connect to the switch and reset its state.

        :param sw_name: switch name as stored in ``topology.db``.
        """
        self.topo = Topology(db="topology.db")
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)
        # Sniff on the eth1 counterpart of the CPU-port interface name.
        # NOTE(review): presumably the other end of the CPU veth pair - confirm.
        self.intf = str(self.topo.get_cpu_port_intf(self.sw_name).replace("eth0", "eth1"))
        self.controller.reset_state()

        if self.cpu_port:
            # Mirror session 100 sends cloned packets to the CPU port.
            self.controller.mirroring_add(100, self.cpu_port)

        # NDP setup is currently disabled; re-enable with:
        # self.init_ndp()

    def config(self):
        """Push table entries to the switch, then configure the hosts."""
        self.fill_tables()
        self.config_hosts()

    def listen_to_icmp6_multicast(self):
        """Only needed for debugging: send solicited-node multicast to the controller."""
        net = self.info['ndp_multicast']
        self.controller.table_add("v6_networks", "controller_debug", [str(net)])

    def fill_tables(self):
        """(Re)populate the routing and address tables for ``self.mode``."""
        self.controller.table_clear("v6_networks")
        for v6route in self.v6_routes[self.mode]:
            self.controller.table_add("v6_networks", "set_egress_port", [str(v6route['net'])], [str(v6route['port'])])

        self.listen_to_icmp6_multicast()

        # Clear the same table that is filled below (was "v4_routing", which
        # left stale v4_networks entries in place).
        self.controller.table_clear("v4_networks")
        for v4route in self.v4_routes[self.mode]:
            self.controller.table_add("v4_networks", "set_egress_port", [str(v4route['net'])], [str(v4route['port'])])

        self.controller.table_clear("v6_addresses")
        for v6addr in self.v6_addresses[self.mode]:
            icmp6_addr = self.gen_ndp_multicast_addr(v6addr)

            # NOTE(review): the match key is only the task id, so with more
            # than one address these entries share a key - confirm the P4
            # table's key layout.
            self.controller.table_add("v6_addresses", "controller_reply", [str(self.task['ICMP6_GENERAL'])], [str(v6addr)])
            self.controller.table_add("v6_addresses", "controller_reply", [str(self.task['ICMP6_NS'])], [str(icmp6_addr)])

    def config_hosts(self):
        """Give each IPv6 host the first address of its route's network.

        Assumptions:
        - all routes are networks (no /128 v6 or /32 v4)
        - hosts get the first ip address in the network
        - host h<port> is attached via interface h<port>-eth0
        Only IPv6 routes are configured here.
        """
        for v6route in self.v6_routes[self.mode]:
            host = "h{}".format(v6route['port'])
            dev = "{}-eth0".format(host)
            net = v6route['net']
            ipaddr = "{}/{}".format(net[1], net.prefixlen)

            self.add_host_ips(host, str(net), str(ipaddr), dev)

    @staticmethod
    def add_host_ips(host, net, ipaddr, dev):
        """Enable IPv6 on *host* and assign *ipaddr* to *dev* (via ``mx``).

        :param host: mininet host name, e.g. "h1".
        :param net: network string, used for logging only.
        :param ipaddr: address with prefix length, e.g. "2001:db8::1/64".
        :param dev: interface name on the host.
        """
        log.debug("Config host: {} {}->{} on {}".format(host, net, ipaddr, dev))

        subprocess.call(["mx", host, "ip", "addr", "flush", "dev", dev])
        for v6dev in ["lo", "default", "all", dev]:
            subprocess.call(["mx", host, "sysctl", "net.ipv6.conf.{}.disable_ipv6=0".format(v6dev)])

        # Set down & up to regain link local address
        subprocess.call(["mx", host, "ip", "link", "set", dev, "down"])
        subprocess.call(["mx", host, "ip", "link", "set", dev, "up"])

        # Now add global address
        subprocess.call(["mx", host, "ip", "addr", "add", ipaddr, "dev", dev])

    def debug_print_pkg(self, pkg, msg="INCOMING"):
        """Log the repr of *pkg* at debug level, prefixed by *msg*."""
        log.debug("{}: {}".format(msg, pkg.__repr__()))

    def debug_format_pkg(self, pkg):
        """Return a short "src => dst" summary for an IPv4/IPv6 packet.

        Other EtherTypes fall back to the parsed packet's repr (previously
        this computed the IP layer, discarded it and returned None).
        """
        packet = Ether(bytes(pkg))

        if packet.type == 0x800:
            ip = packet.getlayer(IP)
        elif packet.type == 0x86dd:
            ip = packet.getlayer(IPv6)
        else:
            ip = None

        if ip is None:
            return packet.__repr__()
        return "{} => {}".format(ip.src, ip.dst)

    def recv_msg_cpu(self, pkg):
        """Handle one packet sniffed on the CPU interface.

        IPv4/IPv6 packets are currently ignored. The custom EtherTypes 0x4242
        and 0x2323 are presumably set by the P4 program (TODO confirm).
        Anything else is reported as broken.
        """
        # bytes() yields the raw frame on both Python 2 and 3; str() does
        # not on Python 3.
        packet = Ether(bytes(pkg))

        self.debug_print_pkg(pkg)

        if packet.type == 0x0800:
            pass
        elif packet.type == 0x86dd:
            pass
        elif packet.type == 0x4242:
            print("Special handling needed")
        elif packet.type == 0x2323:
            # Set back (incorrectly maybe) to IPv6
            packet.type = 0x86dd
            # The format string previously lacked a placeholder, so the
            # packet was never printed.
            print("Debug pkg: {}".format(packet.__repr__()))
        else:
            print("Broken pkg: {}".format(pkg.__repr__()))
            return

    def run_cpu_port_loop(self):
        """Sniff on the CPU interface forever, dispatching to recv_msg_cpu."""
        sniff(iface=self.intf, prn=self.recv_msg_cpu)

    def commandline(self):
        """Parse ``--mode`` / ``--debug`` into ``self.args``, ``self.mode``, ``self.debug``."""
        parser = argparse.ArgumentParser(description='controller++')
        parser.add_argument('--mode', help='Select mode / settings to use', choices=self.modes)
        parser.add_argument('--debug', help='Enable debug logging', action='store_true')

        self.args = parser.parse_args()
        self.mode = self.args.mode
        self.debug = self.args.debug
if __name__ == "__main__":
    # `sys` is already imported at the top of the file; only `os` is new here.
    import os

    # The DEBUG environment variable enables debug logging before argument
    # parsing; --debug (checked below) can also enable it afterwards.
    log.setLevel(logging.DEBUG if "DEBUG" in os.environ else logging.INFO)

    log.info("Booting...")
    log.debug("Debug enabled.")

    sw_name = "s1"
    controller = L2Controller(sw_name)

    controller.commandline()
    if controller.args.debug:
        log.setLevel(logging.DEBUG)
    controller.config()
    controller.run_cpu_port_loop()