diff --git a/datacenterlight/views.py b/datacenterlight/views.py index ae649623..e13a31a6 100644 --- a/datacenterlight/views.py +++ b/datacenterlight/views.py @@ -26,7 +26,9 @@ from utils.forms import ( BillingAddressForm, BillingAddressFormSignup, UserBillingAddressForm, BillingAddress ) -from utils.hosting_utils import get_vm_price_with_vat, get_all_public_keys +from utils.hosting_utils import ( + get_vm_price_with_vat, get_all_public_keys, get_vat_rate_for_country +) from utils.stripe_utils import StripeUtils from utils.tasks import send_plain_email_task from .cms_models import DCLCalculatorPluginModel @@ -414,8 +416,9 @@ class PaymentOrderView(FormView): ) gp_details = { "product_name": product.product_name, - "amount": generic_payment_form.cleaned_data.get( - 'amount' + "amount": product.get_actual_price( + explicit_vat=get_vat_rate_for_country( + address_form["country"]) ), "recurring": generic_payment_form.cleaned_data.get( 'recurring' diff --git a/hosting/models.py b/hosting/models.py index 5f0ec3ef..00c89e11 100644 --- a/hosting/models.py +++ b/hosting/models.py @@ -82,9 +82,10 @@ class GenericProduct(AssignPermissionsMixin, models.Model): def __str__(self): return self.product_name - def get_actual_price(self): + def get_actual_price(self, vat_rate=None): + VAT = vat_rate if vat_rate is not None else self.product_vat return round( - self.product_price + (self.product_price * self.product_vat), 2 + self.product_price + (self.product_price * VAT), 2 ) diff --git a/utils/hosting_utils.py b/utils/hosting_utils.py index b3c47e6e..9492e1de 100644 --- a/utils/hosting_utils.py +++ b/utils/hosting_utils.py @@ -5,7 +5,7 @@ import subprocess from oca.pool import WrongIdError from datacenterlight.models import VMPricing -from hosting.models import UserHostingKey, VMDetail +from hosting.models import UserHostingKey, VMDetail, VATRates from opennebula_api.serializers import VirtualMachineSerializer logger = logging.getLogger(__name__) @@ -150,6 +150,18 @@ def ping_ok(host_ipv6): return True +def get_vat_rate_for_country(country): + vat_rate = VATRates.objects.get( + territory_codes=country, start_date__isnull=False, stop_date=None + ) + if vat_rate: + logger.debug("VAT rate for %s is %s" % (country, vat_rate.rate)) + return vat_rate.rate + else: + logger.debug("Did not find VAT rate for %s, returning 0" % country) + return 0 + + class HostingUtils: @staticmethod def clear_items_from_list(from_list, items_list):