from django.db import models
from django.db.models import Q
from django.contrib.auth import get_user_model
from django.utils.translation import gettext_lazy as _
from django.core.validators import MinValueValidator
from django.utils import timezone
from django.core.exceptions import ObjectDoesNotExist, ValidationError

from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType

import logging
from functools import reduce
import itertools
from math import ceil
import datetime
from calendar import monthrange
from decimal import Decimal

import uncloud_pay.stripe
from uncloud_pay.helpers import beginning_of_month, end_of_month
from uncloud_pay import AMOUNT_DECIMALS, AMOUNT_MAX_DIGITS, COUNTRIES
from uncloud.models import UncloudModel, UncloudStatus

from decimal import Decimal
import decimal

# Used to generate bill due dates: default_payment_delay() adds this
# interval to the current time.
BILL_PAYMENT_DELAY=datetime.timedelta(days=10)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def start_of_month(a_day):
    """Return ``a_day`` moved back to 00:00:00 on the first day of its month."""
    midnight_on_the_first = dict(day=1, hour=0, minute=0, second=0, microsecond=0)
    return a_day.replace(**midnight_on_the_first)

def end_of_month(a_day):
    """Return ``a_day`` moved to 23:59:59 on the last day of its month.

    Microseconds are reset to 0.
    """
    # monthrange() yields (weekday of the 1st, number of days in the month).
    days_in_month = monthrange(a_day.year, a_day.month)[1]
    return a_day.replace(
        day=days_in_month,
        hour=23,
        minute=59,
        second=59,
        microsecond=0,
    )

def start_of_this_month():
    """Return 00:00:00 on the first day of the current month."""
    return start_of_month(timezone.now())

def end_of_this_month():
    """Return 23:59:59 on the last day of the current month."""
    return end_of_month(timezone.now())

def default_payment_delay():
    """Return the current time plus ``BILL_PAYMENT_DELAY``."""
    now = timezone.now()
    return now + BILL_PAYMENT_DELAY

# See https://docs.djangoproject.com/en/dev/ref/models/fields/#field-choices-enum-types
class RecurringPeriod(models.IntegerChoices):
    """
    Billing periods an order can recur at.

    Each member's value is the length of the period in seconds; ONE_TIME (0)
    marks an order that does not recur.

    We don't support months or years, because they vary in length.
    This is not only complicated, but also unfair to the user, as the user pays the same
    amount for different durations.
    """
    PER_365D   = 365*24*3600, _('Per 365 days')
    PER_30D    = 30*24*3600, _('Per 30 days')
    PER_WEEK   = 7*24*3600, _('Per Week')
    PER_DAY    = 24*3600, _('Per Day')
    PER_HOUR   = 3600, _('Per Hour')
    PER_MINUTE = 60, _('Per Minute')
    PER_SECOND = 1, _('Per Second')
    ONE_TIME   = 0, _('Onetime')

class CountryField(models.CharField):
    """CharField preconfigured for country codes.

    Unless the caller overrides them, ``choices`` is ``COUNTRIES``,
    ``default`` is ``'CH'`` and ``max_length`` is 2.
    """

    def __init__(self, *args, **kwargs):
        # Only fill in options the caller did not pass explicitly.
        for option, fallback in (('choices', COUNTRIES),
                                 ('default', 'CH'),
                                 ('max_length', 2)):
            kwargs.setdefault(option, fallback)

        super().__init__(*args, **kwargs)

    def get_internal_type(self):
        """Report this field as a plain ``CharField``."""
        return "CharField"

def get_balance_for_user(user):
    """Return the sum of ``user``'s payments minus the sum of their bills.

    A positive result means the user has paid more than was billed.
    """
    # NOTE(review): relies on a ``total`` attribute on Bill, which is not
    # defined on the Bill model as visible in this file -- confirm.
    billed = sum((bill.total for bill in Bill.objects.filter(owner=user)), 0)
    paid = sum((payment.amount for payment in Payment.objects.filter(owner=user)), 0)
    return paid - billed

class StripeCustomer(models.Model):
    """Links a user account to its Stripe customer identifier."""
    # At most one row per user; the user doubles as the primary key.
    owner = models.OneToOneField( get_user_model(),
            primary_key=True,
            on_delete=models.CASCADE)
    # Customer id handed to uncloud_pay.stripe.charge_customer by
    # PaymentMethod.charge.
    stripe_id = models.CharField(max_length=32)

###
# Payments and Payment Methods.

class Payment(models.Model):
    """A payment received from a user, whatever its source."""

    owner = models.ForeignKey(get_user_model(),
            on_delete=models.CASCADE)

    amount = models.DecimalField(
            default=0.0,
            max_digits=AMOUNT_MAX_DIGITS,
            decimal_places=AMOUNT_DECIMALS,
            validators=[MinValueValidator(0)])

    source = models.CharField(max_length=256,
                              choices = (
                                  ('wire', 'Wire Transfer'),
                                  ('stripe', 'Stripe'),
                                  ('voucher', 'Voucher'),
                                  ('referral', 'Referral'),
                                  ('unknown', 'Unknown')
                              ),
                              default='unknown')
    timestamp = models.DateTimeField(editable=False, auto_now_add=True)

    # We override save() in order to activate products awaiting payment.
    def save(self, *args, **kwargs):
        """
        Save the payment, then activate the products of every bill that this
        payment turned from unpaid into paid.

        NOTE(review): ``Bill.get_unpaid_for`` and ``Bill.activate_products``
        are not defined on the Bill model as visible in this file -- confirm
        they exist before relying on this hook.
        """
        # Materialise the "before" snapshot immediately: if get_unpaid_for
        # returns a lazy QuerySet it would otherwise only be evaluated after
        # the payment is stored, making "before" and "after" identical.
        unpaid_bills_before_payment = set(Bill.get_unpaid_for(self.owner))
        super().save(*args, **kwargs) # Save payment in DB.
        unpaid_bills_after_payment = set(Bill.get_unpaid_for(self.owner))

        newly_paid_bills = unpaid_bills_before_payment - unpaid_bills_after_payment
        for bill in newly_paid_bills:
            bill.activate_products()


class PaymentMethod(models.Model):
    """
    A way for a user to pay. Only methods with source 'stripe' and a
    registered Stripe payment method id can be charged.
    """

    owner = models.ForeignKey(get_user_model(),
            on_delete=models.CASCADE,
            editable=False)
    source = models.CharField(max_length=256,
            choices = (
                ('stripe', 'Stripe'),
                ('unknown', 'Unknown'),
                ),
            default='stripe')
    description = models.TextField()
    primary = models.BooleanField(default=False, editable=False)

    # Only used for "Stripe" source
    stripe_payment_method_id = models.CharField(max_length=32, blank=True, null=True)
    stripe_setup_intent_id = models.CharField(max_length=32, blank=True, null=True)

    @property
    def stripe_card_last4(self):
        """
        Last four digits of the card as reported by Stripe, or None when
        this is not an active Stripe method. Queries Stripe on every access.
        """
        if self.source == 'stripe' and self.active:
            payment_method = uncloud_pay.stripe.get_payment_method(
                    self.stripe_payment_method_id)
            return payment_method.card.last4

        return None

    @property
    def active(self):
        """True for a Stripe method that has a Stripe payment method id."""
        return (self.source == 'stripe'
                and self.stripe_payment_method_id is not None)

    def charge(self, amount):
        """
        Charge ``amount`` through this method and record it as a Payment.

        Returns the created Payment. Raises Exception when the method is
        inactive, the amount is negative, Stripe reports the charge as
        unpaid, or the source cannot be charged.
        """
        if not self.active:
            raise Exception('This payment method is inactive.')

        if amount < 0: # Make sure we don't charge negative amount by errors...
            raise Exception('Cannot charge negative amount.')

        if self.source != 'stripe':
            raise Exception('This payment method is unsupported/cannot be charged.')

        stripe_customer = StripeCustomer.objects.get(owner=self.owner).stripe_id
        stripe_payment = uncloud_pay.stripe.charge_customer(
                amount, stripe_customer, self.stripe_payment_method_id)
        # Only an explicit paid == False counts as failure; a response
        # without a 'paid' key is treated as success, as before.
        if 'paid' in stripe_payment and stripe_payment['paid'] == False:
            raise Exception(stripe_payment['error'])

        return Payment.objects.create(
                owner=self.owner, source=self.source, amount=amount)

    def set_as_primary_for(self, user):
        """
        Make this the only primary payment method of ``user``: every other
        primary method of that user is demoted first.
        """
        # Saved one by one (rather than a bulk update) so that save() and
        # its signals keep running for each demoted method.
        for method in PaymentMethod.objects.filter(owner=user, primary=True):
            method.primary = False
            method.save()

        self.primary = True
        self.save()

    @staticmethod
    def get_primary_for(user):
        """Return the first active primary method of ``user``, or None."""
        methods = PaymentMethod.objects.filter(owner=user)
        for method in methods:
            # Do we want to do something with non-primary method?
            if method.active and method.primary:
                return method

        return None

    class Meta:
        # TODO: limit to one primary method per user.
        # unique_together is no good since it won't allow more than one
        # non-primary method.
        pass

###
# Bills.

class BillingAddress(models.Model):
    """Postal billing address of a user; at most one per user may be active."""

    owner = models.ForeignKey(get_user_model(), on_delete=models.CASCADE)

    organization = models.CharField(max_length=100, blank=True, null=True)
    name = models.CharField(max_length=100)
    street = models.CharField(max_length=100)
    city = models.CharField(max_length=50)
    postal_code = models.CharField(max_length=50)
    country = CountryField(blank=True)
    vat_number = models.CharField(max_length=100, default="", blank=True)
    active = models.BooleanField(default=False)

    class Meta:
        # Partial unique constraint: uniqueness of owner is only enforced
        # among rows with active=True.
        constraints = [
            models.UniqueConstraint(fields=['owner'],
                                    condition=Q(active=True),
                                    name='one_active_billing_address_per_user')
        ]

    @staticmethod
    def get_address_for(user):
        """Return the active address of ``user``.

        Raises ``BillingAddress.DoesNotExist`` when the user has none.
        """
        active_addresses = BillingAddress.objects.filter(owner=user, active=True)
        return active_addresses.get()

    def __str__(self):
        return (f"{self.owner} - {self.name}, {self.street}, "
                f"{self.postal_code} {self.city}, {self.country}")

###
# VAT

class VATRate(models.Model):
    """VAT rate that applies to a territory during a date range."""

    starting_date = models.DateField(blank=True, null=True)
    ending_date = models.DateField(blank=True, null=True)
    territory_codes = models.TextField(blank=True, default='')
    currency_code = models.CharField(max_length=10)
    rate = models.FloatField()
    rate_type = models.TextField(blank=True, default='')
    description = models.TextField(blank=True, default='')

    @staticmethod
    def get_for_country(country_code):
        """
        Return the open-ended VAT rate for ``country_code``.

        The matching row must have a starting date and no ending date.
        Returns 0 when no such row exists. ``MultipleObjectsReturned`` is
        not handled and propagates if several rows match.
        """
        try:
            # Lookups must use the model's actual field names
            # (starting_date / ending_date).
            vat_rate = VATRate.objects.get(
                territory_codes=country_code,
                starting_date__isnull=False,
                ending_date__isnull=True,
            )
        except VATRate.DoesNotExist as dne:
            logger.debug(str(dne))
            logger.debug("Did not find VAT rate for %s, returning 0", country_code)
            return 0

        return vat_rate.rate

###
# Orders.

class Order(models.Model):
    """
    Orders are assumed IMMUTABLE and used as SOURCE OF TRUTH for generating
    bills. Do **NOT** mutate them!
    """

    owner = models.ForeignKey(get_user_model(),
                              on_delete=models.CASCADE,
                              editable=True)

    billing_address = models.ForeignKey(BillingAddress,
                                        on_delete=models.CASCADE)

    description = models.TextField()

    # TODO: enforce ending_date - starting_date to be larger than recurring_period.
    creation_date = models.DateTimeField(auto_now_add=True)
    starting_date = models.DateTimeField(default=timezone.now)
    ending_date = models.DateTimeField(blank=True, null=True)

    recurring_period = models.IntegerField(choices = RecurringPeriod.choices,
                                           default = RecurringPeriod.PER_30D)

    one_time_price = models.DecimalField(default=0.0,
            max_digits=AMOUNT_MAX_DIGITS,
            decimal_places=AMOUNT_DECIMALS,
            validators=[MinValueValidator(0)])

    recurring_price = models.DecimalField(default=0.0,
            max_digits=AMOUNT_MAX_DIGITS,
            decimal_places=AMOUNT_DECIMALS,
            validators=[MinValueValidator(0)])

    # The order this one supersedes, if any.
    replaces = models.ForeignKey('self',
                                 related_name='replaced_by',
                                 on_delete=models.PROTECT,
                                 blank=True,
                                 null=True)

    depends_on = models.ForeignKey('self',
                                   related_name='parent_of',
                                   on_delete=models.PROTECT,
                                   blank=True,
                                   null=True)


    @property
    def count_billed(self):
        """
        How many times this order was billed so far.
        This logic is mainly thought to be for recurring bills, but also works for one time bills
        """

        # BillRecord.order declares no related_name, so the reverse accessor
        # on Order is Django's default ``billrecord_set``.
        return sum(br.quantity for br in self.billrecord_set.all())


    def active_before(self, ending_date):
        """
        NOTE(review): unfinished -- the comparisons below have no effect and
        the method always returns None.
        """
        # Was this order started before the specified ending date?
        if self.starting_date <= ending_date:
            if self.ending_date:
                if self.ending_date > ending_date:
                    pass

    @property
    def is_recurring(self):
        """True unless the order is a one-time order."""
        return self.recurring_period != RecurringPeriod.ONE_TIME

    @property
    def is_terminated(self):
        """True when the order has an ending date that lies in the past."""
        return self.ending_date is not None and self.ending_date < timezone.now()

    def is_terminated_at(self, a_date):
        """True when the order has an ending date that lies before ``a_date``."""
        return self.ending_date is not None and self.ending_date < a_date

    def terminate(self):
        """End the order now, unless it is already terminated."""
        if not self.is_terminated:
            self.ending_date = timezone.now()
            self.save()

    def save(self, *args, **kwargs):
        """
        Save the order after checking its date range.

        Raises ValidationError when ending_date lies before starting_date.
        """
        if self.ending_date and self.ending_date < self.starting_date:
            raise ValidationError("End date cannot be before starting date")

        super().save(*args, **kwargs)

    def generate_initial_bill(self):
        """
        Generate the bill for the month in which this order starts.

        NOTE(review): ``Bill.generate_for`` is not defined on the Bill model
        as visible in this file -- confirm it exists.
        """
        return Bill.generate_for(self.starting_date.year, self.starting_date.month, self.owner)

    @property
    def records(self):
        """QuerySet of the OrderRecords attached to this order."""
        return OrderRecord.objects.filter(order=self)

    def __str__(self):
        return f"Order {self.owner}-{self.id}"


class Bill(models.Model):
    """
    A bill issued to a user for a time range. It is linked to the orders it
    covers through BillRecord entries.
    """

    owner = models.ForeignKey(get_user_model(),
            on_delete=models.CASCADE)

    creation_date = models.DateTimeField(auto_now_add=True)
    starting_date = models.DateTimeField(default=start_of_this_month)
    ending_date = models.DateTimeField()
    due_date = models.DateField(default=default_payment_delay)

    # what is valid for? should this be "final"?
    valid = models.BooleanField(default=True)

    # Mapping to BillRecords
    # https://stackoverflow.com/questions/4443190/djangos-manytomany-relationship-with-additional-fields
    bill_records = models.ManyToManyField(Order, through="BillRecord")

    class Meta:
        constraints = [
            models.UniqueConstraint(fields=['owner',
                                            'starting_date',
                                            'ending_date' ],
                                    name='one_bill_per_month_per_user')
        ]

    def __str__(self):
        return f"Bill {self.owner}-{self.id}"


    @classmethod
    def create_next_bill_for_user(cls, user):
        """
        Work out the time range of the next bill for ``user``.

        NOTE(review): unfinished -- the range is computed but no bill is
        created and nothing is returned yet.
        """
        last_bill = cls.objects.filter(owner=user).order_by('id').last()
        all_orders = Order.objects.filter(owner=user).order_by('id')
        first_order = all_orders.first()

        # Calculate the start date: right after the previous bill, else at
        # the start of the first order, else now.
        if last_bill:
            starting_date = last_bill.ending_date + datetime.timedelta(seconds=1)
        elif first_order:
            starting_date = first_order.starting_date
        else:
            starting_date = timezone.now()

        ending_date = end_of_month(starting_date)

        for order in all_orders:
            # check if order needs to be billed
            # check if order has previous billing record

            pass

    @classmethod
    def create_all_bills(cls):
        """
        FIXME(review): unfinished -- ``year`` and ``month`` are never defined
        and assign_orders_to_bill is an instance method, so this raises
        NameError as soon as at least one user exists.
        """
        for owner in get_user_model().objects.all():
            # mintime = time of first order
            # maxtime = time of last order
            # iterate month based through it

            cls.assign_orders_to_bill(owner, year, month)
            pass

    def assign_orders_to_bill(self, owner, year, month):
        """
        Generate a bill for the specific month of a user.

        First handle all one time orders

        Find all one time orders that have a starting date that falls into
        this month (recurring_period=RecurringPeriod.ONE_TIME).
        Can we do this even for recurring / all of them?

        ``year`` and ``month`` are currently unused.

        FIXME:

        - limit this to active users in the future! (2020-05-23)
        """

        # FIXME: add something to check  whether the order should be billed at all - i.e. a marker that
        # disables searching -> optimization for later
        # Create the initial bill record
        # FIXME: maybe limit not even to starting/ending date, but to empty_bill record -- to be fixed in the future
        # for order in Order.objects.filter(Q(starting_date__gte=self.starting_date),
        #                                   Q(starting_date__lte=self.ending_date),

        # FIXME below: only check for active orders

        # Ensure all orders of that owner have at least one bill record.
        # ``billrecord`` is the default reverse query name of BillRecord.order.
        for order in Order.objects.filter(owner=owner,
                                          billrecord__isnull=True):

            # BillRecord.order is a required foreign key, so it has to be
            # passed explicitly.
            BillRecord.objects.create(
                bill=self,
                order=order,
                quantity=1,
                starting_date=order.starting_date,
                ending_date=order.starting_date
                + datetime.timedelta(seconds=order.recurring_period))


        # For each recurring order get the usage and bill it
        for order in Order.objects.filter(~Q(recurring_period=RecurringPeriod.ONE_TIME),
                                          Q(starting_date__lt=self.starting_date),
                                          owner=owner):

            if order.recurring_period > 0: # avoid div/0 - these are one time payments

                # How much time will have passed by the end of the billing cycle
                td = self.ending_date - order.starting_date

                # How MANY times it will have been used by then
                used_times = ceil(
                    td / datetime.timedelta(seconds=order.recurring_period))

                # How many times it WAS billed -- inferred from the bills
                # that link to it (default reverse accessor of bill_records).
                billed_times = order.bill_set.count()

                if used_times > billed_times:
                    billing_times = used_times - billed_times

                    # ALSO REGISTER THE TIME PERIOD!
                    pass



class BillRecord(models.Model):
    """
    Entry of a bill, dynamically generated from an order.

    Serves as the ``through`` model of ``Bill.bill_records``.
    """

    bill = models.ForeignKey(Bill, on_delete=models.CASCADE)
    order = models.ForeignKey(Order, on_delete=models.CASCADE)

    # How many times the order has been used in this record
    quantity = models.DecimalField(max_digits=19, decimal_places=10)

    # The timeframe the bill record is for can (and probably often will) differ
    # from the bill time

    creation_date = models.DateTimeField(auto_now_add=True)
    starting_date = models.DateTimeField()
    ending_date = models.DateTimeField()

    def __str__(self):
        return f"{self.bill}: {self.quantity} x {self.order}"


class OrderRecord(models.Model):
    """
    Order records store billing information for products: the actual product
    might be mutated and/or moved to another order but we do not want to lose
    the details of old orders.

    Used as source of truth to dynamically generate bill entries.
    """

    order = models.ForeignKey(Order, on_delete=models.CASCADE)

    one_time_price = models.DecimalField(default=0.0,
            max_digits=AMOUNT_MAX_DIGITS,
            decimal_places=AMOUNT_DECIMALS,
            validators=[MinValueValidator(0)])
    recurring_price = models.DecimalField(default=0.0,
            max_digits=AMOUNT_MAX_DIGITS,
            decimal_places=AMOUNT_DECIMALS,
            validators=[MinValueValidator(0)])

    description = models.TextField()


    # The period and the dates are not stored on the record; they are read
    # from the order it belongs to.
    @property
    def recurring_period(self):
        """Recurring period of the parent order."""
        return self.order.recurring_period

    @property
    def starting_date(self):
        """Starting date of the parent order."""
        return self.order.starting_date

    @property
    def ending_date(self):
        """Ending date of the parent order (may be None)."""
        return self.order.ending_date


###
# Products

class Product(UncloudModel):
    """
    Abstract base for everything that can be ordered.

    Saving a product without an order creates one. Children are expected to
    provide ``recurring_price`` and may override ``one_time_price`` and
    ``default_recurring_period``.
    """

    owner = models.ForeignKey(get_user_model(),
                              on_delete=models.CASCADE,
                              editable=False)

    description = "Generic Product"

    order = models.ForeignKey(Order,
                              on_delete=models.CASCADE,
                              editable=True,
                              null=True)
    # FIXME: editable=True -> is in the admin, but also editable in DRF

    status = models.CharField(max_length=32,
            choices=UncloudStatus.choices,
            default=UncloudStatus.AWAITING_PAYMENT)

    # Default period for all products
    default_recurring_period = RecurringPeriod.PER_30D

    def create_order_at(self, when_to_start, *args, **kwargs):
        """
        Create an order for this product starting at ``when_to_start`` and
        return it. The order is NOT attached to ``self.order``.

        Extra positional/keyword arguments are accepted but unused.
        Raises ``BillingAddress.DoesNotExist`` when the owner has no active
        billing address.
        """
        billing_address = BillingAddress.get_address_for(self.owner)

        return Order.objects.create(owner=self.owner,
                                    billing_address=billing_address,
                                    starting_date=when_to_start,
                                    one_time_price=self.one_time_price,
                                    recurring_period=self.default_recurring_period,
                                    recurring_price=self.recurring_price,
                                    description=str(self))

    def create_or_update_order(self, when_to_start=None):
        """
        Attach a fresh order to this product.

        Without an existing order a first one is created. Otherwise a new
        order replacing the current one is created and the current one is
        ended one second before ``when_to_start``. ``when_to_start`` defaults
        to now. The product itself is not saved here.

        Raises ValidationError when a first order is needed but the owner has
        no active billing address.
        """
        if not when_to_start:
            when_to_start = timezone.now()

        if not self.order:
            # get_address_for raises instead of returning None when there is
            # no active address; normalise both cases to ValidationError.
            try:
                billing_address = BillingAddress.get_address_for(self.owner)
            except ObjectDoesNotExist:
                billing_address = None

            if not billing_address:
                raise ValidationError("Cannot order without a billing address")

            self.order = Order.objects.create(owner=self.owner,
                                              billing_address=billing_address,
                                              starting_date=when_to_start,
                                              one_time_price=self.one_time_price,
                                              recurring_period=self.default_recurring_period,
                                              recurring_price=self.recurring_price,
                                              description=str(self))

        else:
            previous_order = self.order
            when_to_end = when_to_start - datetime.timedelta(seconds=1)

            new_order = Order.objects.create(owner=self.owner,
                                             billing_address=previous_order.billing_address,
                                             starting_date=when_to_start,
                                             one_time_price=self.one_time_price,
                                             recurring_period=self.default_recurring_period,
                                             recurring_price=self.recurring_price,
                                             description=str(self),
                                             replaces=previous_order)

            # The Order field is called ``ending_date``; assigning any other
            # attribute name would not be persisted by save().
            previous_order.ending_date = when_to_end
            previous_order.save()

            self.order = new_order


    def save(self, *args, **kwargs):
        """Save the product, creating its order first if it has none."""
        # Create order if there is none already
        if not self.order:
            self.create_or_update_order()

        super().save(*args, **kwargs)

    @property
    def recurring_price(self):
        """Price per default recurring period; returns None in this base class."""
        pass # To be implemented in child.

    @property
    def one_time_price(self):
        """
        Default is 0 CHF
        """
        return 0

    @property
    def billing_address(self):
        """Billing address of the attached order."""
        return self.order.billing_address

    @staticmethod
    def allowed_recurring_periods():
        """Recurring periods this product may be ordered with (all by default)."""
        return RecurringPeriod.choices

    class Meta:
        abstract = True

    def discounted_price_by_period(self, requested_period):
        """
        Each product has a standard recurring period for which
        we define a pricing. I.e. VPN is usually year, VM is usually monthly.

        The user can opt-in to use a different period, which influences the price:
        The longer a user commits, the higher the discount.

        Products can also be limited in the available periods. For instance
        a VPN only makes sense to be bought for at least one day.

        Rules are as follows:

        given a standard recurring period of ..., changing to ... modifies price ...


        # One month for free if buying / year, compared to a month: about 8.33% discount
        per_year -> per_month -> /11
        per_month -> per_year -> *11

        # Month has 30.42 days on average. About 7.9% discount to go monthly
        per_month -> per_day    -> /28
        per_day   -> per_month  -> *28

        # Day has 24h, give one for free (not implemented below yet)
        per_day   -> per_hour   -> /23
        per_hour  -> per_day    -> *23


        Examples

        VPN @ 120CHF/y becomes
        - 10.91 CHF/month (130.91 CHF/year)
        - 0.39 CHF/day (142.21 CHF/year)

        VM @ 15 CHF/month becomes
        - 165 CHF/year (13.75 CHF/month)
        - 0.54 CHF/day (16.30 CHF/month)

        Only PER_365D, PER_30D and PER_DAY are supported, both as default and
        as requested period. Raises Exception for every other combination.
        """

        default_period = self.default_recurring_period

        if default_period == RecurringPeriod.PER_365D:
            if requested_period == RecurringPeriod.PER_365D:
                return self.recurring_price
            if requested_period == RecurringPeriod.PER_30D:
                return self.recurring_price/11.
            if requested_period == RecurringPeriod.PER_DAY:
                return self.recurring_price/11./28.

        elif default_period == RecurringPeriod.PER_30D:
            if requested_period == RecurringPeriod.PER_365D:
                return self.recurring_price*11
            if requested_period == RecurringPeriod.PER_30D:
                return self.recurring_price
            if requested_period == RecurringPeriod.PER_DAY:
                return self.recurring_price/28.

        elif default_period == RecurringPeriod.PER_DAY:
            if requested_period == RecurringPeriod.PER_365D:
                return self.recurring_price*11*28
            if requested_period == RecurringPeriod.PER_30D:
                return self.recurring_price*28
            if requested_period == RecurringPeriod.PER_DAY:
                return self.recurring_price

        # Reached for an unsupported default period as well as for an
        # unsupported requested period, instead of silently returning None.
        # FIXME: use the right type of exception here!
        raise Exception("Did not implement the discounter for this case")