#!/usr/bin/env python

# from switch_calc_headers import *
from scapy.all import *
from nf_sim_tools import *
from collections import OrderedDict
import sss_sdnet_tuples

########################
# pkt generation tools #
########################

# Every packet passed to applyPkt / expPkt, in call order, all ports mixed.
pktsApplied = []
pktsExpected = []

# Per-port packet lists for the SUME simulation, keyed by port index 0..3
# (insertion order 0, 1, 2, 3).
nf_applied = OrderedDict((port, []) for port in range(4))
nf_expected = OrderedDict((port, []) for port in range(4))

# Bit mask written into the src_port / dst_port tuple fields for each
# interface name. nf0..nf3 occupy bits 0, 2, 4 and 6, dma0 uses bit 1, and
# "bcast" is the OR of the four nf port bits.
nf_port_map = {
    "nf0":   0b00000001,
    "nf1":   0b00000100,
    "nf2":   0b00010000,
    "nf3":   0b01000000,
    "dma0":  0b00000010,
    "bcast": 0b01010101,
}

# Interface name -> index into nf_applied / nf_expected ("nf0" -> 0, ...).
nf_id_map = dict(("nf%d" % port, port) for port in range(4))

# Reset the SDNet tuple output before any packets are generated
# (presumably truncates the tuple files -- see sss_sdnet_tuples).
sss_sdnet_tuples.clear_tuple_files()

def applyPkt(pkt, ingress, time):
    """Queue *pkt* for injection on interface *ingress* at timestamp *time*.

    Appends the packet to ``pktsApplied`` and to the per-port list in
    ``nf_applied``. Sets ``src_port`` of both the input and the expected
    SUME tuples to the ``nf_port_map`` bit mask for *ingress*.

    A stamped *copy* of *pkt* is stored. Callers may therefore reuse one
    packet object across several calls without a later timestamp
    overwriting the ones already recorded. The caller's object still has
    ``time`` assigned, as before.

    :param pkt: scapy packet to inject.
    :param ingress: physical port name, one of "nf0".."nf3".
    :param time: injection timestamp stored in the packet's ``time``.
    :raises KeyError: if *ingress* is not a physical port. It is raised
        before any module state is modified.
    """
    # Resolve both lookups first so a bad port name leaves no partial state.
    port_bits = nf_port_map[ingress]
    port_queue = nf_applied[nf_id_map[ingress]]

    # Kept for backward compatibility: callers may read the stamp back.
    pkt.time = time
    stamped = pkt.copy()
    stamped.time = time

    pktsApplied.append(stamped)
    sss_sdnet_tuples.sume_tuple_in['src_port'] = port_bits
    sss_sdnet_tuples.sume_tuple_expect['src_port'] = port_bits
    port_queue.append(stamped)

def expPkt(pkt, egress):
    """Record *pkt* as expected output on *egress*, then write the tuples.

    Sets ``dst_port`` of the expected SUME tuple to ``nf_port_map[egress]``.
    It then calls ``sss_sdnet_tuples.write_tuples()``, presumably emitting
    the tuple pair whose ``src_port`` was set by the preceding applyPkt call.

    The packet is also queued on the per-port ``nf_expected`` lists:
      * "nf0".."nf3" -> that port's list only;
      * "bcast"      -> all four lists;
      * any other known name (e.g. "dma0") -> no per-port list, only
        ``pktsExpected``.
    """
    pktsExpected.append(pkt)
    sss_sdnet_tuples.sume_tuple_expect['dst_port'] = nf_port_map[egress]
    sss_sdnet_tuples.write_tuples()

    if egress == 'bcast':
        targets = (0, 1, 2, 3)
    elif egress in ("nf0", "nf1", "nf2", "nf3"):
        targets = (nf_id_map[egress],)
    else:
        targets = ()

    for port in targets:
        nf_expected[port].append(pkt)

def print_summary(pkts):
    for pkt in pkts:
        print "summary = ", pkt.summary()

def write_pcap_files():
    wrpcap("src.pcap", pktsApplied)
    wrpcap("dst.pcap", pktsExpected)

    for i in nf_applied.keys():
        if (len(nf_applied[i]) > 0):
            wrpcap('nf{0}_applied.pcap'.format(i), nf_applied[i])

    for i in nf_expected.keys():
        if (len(nf_expected[i]) > 0):
            wrpcap('nf{0}_expected.pcap'.format(i), nf_expected[i])

    for i in nf_applied.keys():
        print "nf{0}_applied times: ".format(i), [p.time for p in nf_applied[i]]

#####################
# generate testdata #
#####################

MACSRC = "08:11:11:11:11:08"
# MAC0..MAC3 are 08:22:22:22:22:00 .. 08:22:22:22:22:03.
MAC0, MAC1, MAC2, MAC3 = ("08:22:22:22:22:%02x" % idx for idx in range(4))

# Module-level counter. Each test function below keeps its own local pktCnt
# instead of using this one.
pktCnt = 0

INDEX_WIDTH = 4
REG_DEPTH = 1 << INDEX_WIDTH  # 16

# NOTE(review): INDEX_WIDTH, REG_DEPTH, NUM_KEYS and lookup_table are not
# referenced by any test in this script; presumably left over from a
# register / table-lookup test -- confirm before removing.
NUM_KEYS = 4
# key k -> a single 1 bit at position 4*k: 0x1, 0x10, 0x100, 0x1000
lookup_table = dict((key, 1 << (4 * key)) for key in range(NUM_KEYS))

def test_port1():
    """Single-port smoke test: one Ethernet frame in on nf0, expected out on nf0.

    The frame (MAC1 -> MAC2) is padded with pad_pkt(..., 64) and applied at
    timestamp 1.
    """
    frame = pad_pkt(Ether(src=MAC1, dst=MAC2), 64)
    applyPkt(frame, 'nf0', 1)
    expPkt(frame, 'nf0')

def test_all_ports():
    """Inject one Ethernet frame on each of nf1, nf2 and nf3 (timestamps 1..3)
    and expect each to be broadcast to all four nf ports.

    A fresh packet object is built for every ingress port. applyPkt assigns
    ``pkt.time`` on the object it is given, so reusing one object for all
    three calls left every applied packet carrying the last timestamp (3).
    """
    for pktCnt, ingress in enumerate(('nf1', 'nf2', 'nf3'), 1):
        pkt = pad_pkt(Ether(dst=MAC2, src=MAC1), 64)
        applyPkt(pkt, ingress, pktCnt)
        # 'bcast' queues the frame on the expected lists of nf0..nf3.
        expPkt(pkt, 'bcast')

def test_mirror():
    """Mirror test: a frame sent in on nf1 should come back out on nf0 with
    its MAC addresses swapped.

    Injected:  MAC1 -> MAC2 on nf1 at timestamp 1 (padded to 64 via pad_pkt).
    Expected:  MAC2 -> MAC1 on nf0.
    """
    sent = pad_pkt(Ether(src=MAC1, dst=MAC2), 64)
    applyPkt(sent, 'nf1', 1)

    reflected = pad_pkt(Ether(src=MAC2, dst=MAC1), 64)
    expPkt(reflected, 'nf0')

    # TODO: earlier drafts also sketched IPv6 (fe80::1 <-> fe80::2) and
    # IPv6/TCP (ports 42 <-> 23) variants of this check, both applied and
    # expected on nf0.

# Scenario selection: uncomment additional scenarios to include them.
# All scenarios append to the same module-level packet lists, and each one
# restarts its timestamps at 1, so enabling several merges their traffic.
# test_mirror()
# test_port1()
test_all_ports()

# Emit src/dst and per-port .pcap files from everything generated above.
write_pcap_files()