Add utils app

This commit is contained in:
M.Ravi 2023-12-06 16:46:30 +05:30
parent 72c3149cc0
commit 6a74124adf
24 changed files with 4235 additions and 0 deletions

0
utils/__init__.py Executable file
View file

7
utils/admin.py Executable file
View file

@ -0,0 +1,7 @@
from django.contrib import admin
from .models import UserBillingAddress
# Register your models here.
admin.site.register(UserBillingAddress)

5
utils/apps.py Executable file
View file

@ -0,0 +1,5 @@
from django.apps import AppConfig
class UtilsConfig(AppConfig):
name = 'utils'

13
utils/backend.py Executable file
View file

@ -0,0 +1,13 @@
import logging
from django.contrib.auth.backends import ModelBackend
logger = logging.getLogger(__name__)
class MyLDAPBackend(ModelBackend):
def authenticate(self, username=None, password=None, **kwargs):
user = super().authenticate(username, password, **kwargs)
if user:
user.create_ldap_account(password)
return user

31
utils/context_processor.py Executable file
View file

@ -0,0 +1,31 @@
from django.conf import settings
def google_analytics(request):
"""
Use the variables returned in this function to
render your Google Analytics tracking code template.
Also check whether the site is a tenant site and create a corresponding
variable to indicate this
"""
host = request.get_host()
ga_prop_id = getattr(settings, 'GOOGLE_ANALYTICS_PROPERTY_IDS', False).get(
host)
which_urlspy = settings.MULTISITE_CMS_URLS.get(host)
if ga_prop_id is None:
# Try checking if we have a www in host, if yes we remove
# that and check in the dict again
if host.startswith('www.'):
ga_prop_id = getattr(settings, 'GOOGLE_ANALYTICS_PROPERTY_IDS',
False).get(host[4:])
which_urlspy = settings.MULTISITE_CMS_URLS.get(host[4:])
return_dict = {}
if not settings.DEBUG and ga_prop_id:
return_dict['GOOGLE_ANALYTICS_PROPERTY_ID'] = ga_prop_id
if which_urlspy:
if which_urlspy.endswith("multi"):
return_dict['IS_TENANT_SITE'] = True
return return_dict

256
utils/fields.py Executable file
View file

@ -0,0 +1,256 @@
from django.utils.translation import gettext as _
from django.db import models
# http://xml.coverpages.org/country3166.html
COUNTRIES = (
('AD', _('Andorra')),
('AE', _('United Arab Emirates')),
('AF', _('Afghanistan')),
('AG', _('Antigua & Barbuda')),
('AI', _('Anguilla')),
('AL', _('Albania')),
('AM', _('Armenia')),
('AN', _('Netherlands Antilles')),
('AO', _('Angola')),
('AQ', _('Antarctica')),
('AR', _('Argentina')),
('AS', _('American Samoa')),
('AT', _('Austria')),
('AU', _('Australia')),
('AW', _('Aruba')),
('AZ', _('Azerbaijan')),
('BA', _('Bosnia and Herzegovina')),
('BB', _('Barbados')),
('BD', _('Bangladesh')),
('BE', _('Belgium')),
('BF', _('Burkina Faso')),
('BG', _('Bulgaria')),
('BH', _('Bahrain')),
('BI', _('Burundi')),
('BJ', _('Benin')),
('BM', _('Bermuda')),
('BN', _('Brunei Darussalam')),
('BO', _('Bolivia')),
('BR', _('Brazil')),
('BS', _('Bahama')),
('BT', _('Bhutan')),
('BV', _('Bouvet Island')),
('BW', _('Botswana')),
('BY', _('Belarus')),
('BZ', _('Belize')),
('CA', _('Canada')),
('CC', _('Cocos (Keeling) Islands')),
('CF', _('Central African Republic')),
('CG', _('Congo')),
('CH', _('Switzerland')),
('CI', _('Ivory Coast')),
('CK', _('Cook Iislands')),
('CL', _('Chile')),
('CM', _('Cameroon')),
('CN', _('China')),
('CO', _('Colombia')),
('CR', _('Costa Rica')),
('CU', _('Cuba')),
('CV', _('Cape Verde')),
('CX', _('Christmas Island')),
('CY', _('Cyprus')),
('CZ', _('Czech Republic')),
('DE', _('Germany')),
('DJ', _('Djibouti')),
('DK', _('Denmark')),
('DM', _('Dominica')),
('DO', _('Dominican Republic')),
('DZ', _('Algeria')),
('EC', _('Ecuador')),
('EE', _('Estonia')),
('EG', _('Egypt')),
('EH', _('Western Sahara')),
('ER', _('Eritrea')),
('ES', _('Spain')),
('ET', _('Ethiopia')),
('FI', _('Finland')),
('FJ', _('Fiji')),
('FK', _('Falkland Islands (Malvinas)')),
('FM', _('Micronesia')),
('FO', _('Faroe Islands')),
('FR', _('France')),
('FX', _('France, Metropolitan')),
('GA', _('Gabon')),
('GB', _('United Kingdom (Great Britain)')),
('GD', _('Grenada')),
('GE', _('Georgia')),
('GF', _('French Guiana')),
('GH', _('Ghana')),
('GI', _('Gibraltar')),
('GL', _('Greenland')),
('GM', _('Gambia')),
('GN', _('Guinea')),
('GP', _('Guadeloupe')),
('GQ', _('Equatorial Guinea')),
('GR', _('Greece')),
('GS', _('South Georgia and the South Sandwich Islands')),
('GT', _('Guatemala')),
('GU', _('Guam')),
('GW', _('Guinea-Bissau')),
('GY', _('Guyana')),
('HK', _('Hong Kong')),
('HM', _('Heard & McDonald Islands')),
('HN', _('Honduras')),
('HR', _('Croatia')),
('HT', _('Haiti')),
('HU', _('Hungary')),
('ID', _('Indonesia')),
('IE', _('Ireland')),
('IL', _('Israel')),
('IN', _('India')),
('IO', _('British Indian Ocean Territory')),
('IQ', _('Iraq')),
('IR', _('Islamic Republic of Iran')),
('IS', _('Iceland')),
('IT', _('Italy')),
('JM', _('Jamaica')),
('JO', _('Jordan')),
('JP', _('Japan')),
('KE', _('Kenya')),
('KG', _('Kyrgyzstan')),
('KH', _('Cambodia')),
('KI', _('Kiribati')),
('KM', _('Comoros')),
('KN', _('St. Kitts and Nevis')),
('KP', _('Korea, Democratic People\'s Republic of')),
('KR', _('Korea, Republic of')),
('KW', _('Kuwait')),
('KY', _('Cayman Islands')),
('KZ', _('Kazakhstan')),
('LA', _('Lao People\'s Democratic Republic')),
('LB', _('Lebanon')),
('LC', _('Saint Lucia')),
('LI', _('Liechtenstein')),
('LK', _('Sri Lanka')),
('LR', _('Liberia')),
('LS', _('Lesotho')),
('LT', _('Lithuania')),
('LU', _('Luxembourg')),
('LV', _('Latvia')),
('LY', _('Libyan Arab Jamahiriya')),
('MA', _('Morocco')),
('MC', _('Monaco')),
('MD', _('Moldova, Republic of')),
('MG', _('Madagascar')),
('MH', _('Marshall Islands')),
('ML', _('Mali')),
('MN', _('Mongolia')),
('MM', _('Myanmar')),
('MO', _('Macau')),
('MP', _('Northern Mariana Islands')),
('MQ', _('Martinique')),
('MR', _('Mauritania')),
('MS', _('Monserrat')),
('MT', _('Malta')),
('MU', _('Mauritius')),
('MV', _('Maldives')),
('MW', _('Malawi')),
('MX', _('Mexico')),
('MY', _('Malaysia')),
('MZ', _('Mozambique')),
('NA', _('Namibia')),
('NC', _('New Caledonia')),
('NE', _('Niger')),
('NF', _('Norfolk Island')),
('NG', _('Nigeria')),
('NI', _('Nicaragua')),
('NL', _('Netherlands')),
('NO', _('Norway')),
('NP', _('Nepal')),
('NR', _('Nauru')),
('NU', _('Niue')),
('NZ', _('New Zealand')),
('OM', _('Oman')),
('PA', _('Panama')),
('PE', _('Peru')),
('PF', _('French Polynesia')),
('PG', _('Papua New Guinea')),
('PH', _('Philippines')),
('PK', _('Pakistan')),
('PL', _('Poland')),
('PM', _('St. Pierre & Miquelon')),
('PN', _('Pitcairn')),
('PR', _('Puerto Rico')),
('PT', _('Portugal')),
('PW', _('Palau')),
('PY', _('Paraguay')),
('QA', _('Qatar')),
('RE', _('Reunion')),
('RO', _('Romania')),
('RU', _('Russian Federation')),
('RW', _('Rwanda')),
('SA', _('Saudi Arabia')),
('SB', _('Solomon Islands')),
('SC', _('Seychelles')),
('SD', _('Sudan')),
('SE', _('Sweden')),
('SG', _('Singapore')),
('SH', _('St. Helena')),
('SI', _('Slovenia')),
('SJ', _('Svalbard & Jan Mayen Islands')),
('SK', _('Slovakia')),
('SL', _('Sierra Leone')),
('SM', _('San Marino')),
('SN', _('Senegal')),
('SO', _('Somalia')),
('SR', _('Suriname')),
('ST', _('Sao Tome & Principe')),
('SV', _('El Salvador')),
('SY', _('Syrian Arab Republic')),
('SZ', _('Swaziland')),
('TC', _('Turks & Caicos Islands')),
('TD', _('Chad')),
('TF', _('French Southern Territories')),
('TG', _('Togo')),
('TH', _('Thailand')),
('TJ', _('Tajikistan')),
('TK', _('Tokelau')),
('TM', _('Turkmenistan')),
('TN', _('Tunisia')),
('TO', _('Tonga')),
('TP', _('East Timor')),
('TR', _('Turkey')),
('TT', _('Trinidad & Tobago')),
('TV', _('Tuvalu')),
('TW', _('Taiwan, Province of China')),
('TZ', _('Tanzania, United Republic of')),
('UA', _('Ukraine')),
('UG', _('Uganda')),
('UM', _('United States Minor Outlying Islands')),
('US', _('United States of America')),
('UY', _('Uruguay')),
('UZ', _('Uzbekistan')),
('VA', _('Vatican City State (Holy See)')),
('VC', _('St. Vincent & the Grenadines')),
('VE', _('Venezuela')),
('VG', _('British Virgin Islands')),
('VI', _('United States Virgin Islands')),
('VN', _('Viet Nam')),
('VU', _('Vanuatu')),
('WF', _('Wallis & Futuna Islands')),
('WS', _('Samoa')),
('YE', _('Yemen')),
('YT', _('Mayotte')),
('YU', _('Yugoslavia')),
('ZA', _('South Africa')),
('ZM', _('Zambia')),
('ZR', _('Zaire')),
('ZW', _('Zimbabwe')),
)
class CountryField(models.CharField):
def __init__(self, *args, **kwargs):
kwargs.setdefault('choices', COUNTRIES)
kwargs.setdefault('default', 'CH')
kwargs.setdefault('max_length', 2)
super(CountryField, self).__init__(*args, **kwargs)
def get_internal_type(self):
return "CharField"

216
utils/forms.py Executable file
View file

@ -0,0 +1,216 @@
from django import forms
from django.contrib.auth import authenticate
from django.core.mail import EmailMultiAlternatives
from django.template.loader import render_to_string
from django.utils.translation import gettext_lazy as _
from membership.models import CustomUser
from .models import ContactMessage, BillingAddress, UserBillingAddress
# from utils.fields import CountryField
class SignupFormMixin(forms.ModelForm):
confirm_password = forms.CharField(widget=forms.PasswordInput())
password = forms.CharField(widget=forms.PasswordInput())
class Meta:
model = CustomUser
fields = ['name', 'email', 'password']
widgets = {
'name': forms.TextInput(
attrs={'placeholder': _('Enter your name or company name')}),
}
def clean_confirm_password(self):
password = self.cleaned_data.get('password')
confirm_password = self.cleaned_data.get('confirm_password')
if not confirm_password == password:
raise forms.ValidationError("Passwords don't match")
return confirm_password
class LoginFormMixin(forms.Form):
email = forms.CharField(widget=forms.EmailInput())
password = forms.CharField(widget=forms.PasswordInput())
class Meta:
fields = ['email', 'password']
def clean(self):
email = self.cleaned_data.get('email')
password = self.cleaned_data.get('password')
is_auth = authenticate(email=email, password=password)
if not is_auth:
raise forms.ValidationError(
_("Your username and/or password were incorrect."))
return self.cleaned_data
def clean_email(self):
email = self.cleaned_data.get('email')
try:
CustomUser.objects.get(email=email)
return email
except CustomUser.DoesNotExist:
raise forms.ValidationError(_("User does not exist"))
class ResendActivationEmailForm(forms.Form):
email = forms.CharField(widget=forms.EmailInput())
class Meta:
fields = ['email']
def clean_email(self):
email = self.cleaned_data.get('email')
try:
c = CustomUser.objects.get(email=email)
if c.validated == 1:
raise forms.ValidationError(
_("The account is already active."))
return email
except CustomUser.DoesNotExist:
raise forms.ValidationError(_("User does not exist"))
class PasswordResetRequestForm(forms.Form):
email = forms.CharField(widget=forms.EmailInput())
class Meta:
fields = ['email']
def clean_email(self):
email = self.cleaned_data.get('email')
try:
CustomUser.objects.get(email=email)
return email
except CustomUser.DoesNotExist:
raise forms.ValidationError(_("User does not exist"))
class SetPasswordForm(forms.Form):
"""
A form that lets a user change set their password without entering the old
password
"""
error_messages = {
'password_mismatch': _("The two password fields didn't match."),
}
new_password1 = forms.CharField(label=_("New password"),
widget=forms.PasswordInput)
new_password2 = forms.CharField(label=_("New password confirmation"),
widget=forms.PasswordInput)
def clean_new_password2(self):
password1 = self.cleaned_data.get('new_password1')
password2 = self.cleaned_data.get('new_password2')
if password1 and password2:
if password1 != password2:
raise forms.ValidationError(
self.error_messages['password_mismatch'],
code='password_mismatch', )
return password2
class EditCreditCardForm(forms.Form):
token = forms.CharField(widget=forms.HiddenInput())
class BillingAddressForm(forms.ModelForm):
token = forms.CharField(widget=forms.HiddenInput(), required=False)
card = forms.CharField(widget=forms.HiddenInput(), required=False)
class Meta:
model = BillingAddress
fields = ['cardholder_name', 'street_address',
'city', 'postal_code', 'country', 'vat_number']
labels = {
'cardholder_name': _('Cardholder Name'),
'street_address': _('Street Address'),
'city': _('City'),
'postal_code': _('Postal Code'),
'Country': _('Country'),
'VAT Number': _('VAT Number')
}
class BillingAddressFormSignup(BillingAddressForm):
name = forms.CharField(label=_('Name'))
email = forms.EmailField(label=_('Email Address'))
field_order = ['name', 'email']
class Meta:
model = BillingAddress
fields = ['name', 'email', 'cardholder_name', 'street_address',
'city', 'postal_code', 'country', 'vat_number']
labels = {
'name': 'Name',
'email': _('Email'),
'cardholder_name': _('Cardholder Name'),
'street_address': _('Street Address'),
'city': _('City'),
'postal_code': _('Postal Code'),
'Country': _('Country'),
'vat_number': _('VAT Number')
}
def clean_email(self):
email = self.cleaned_data.get('email')
try:
CustomUser.objects.get(email=email)
raise forms.ValidationError(
_("The email %(email)s is already registered with us. "
"Please reset your password and access your account.") %
{'email': email}
)
except CustomUser.DoesNotExist:
return email
class UserBillingAddressForm(forms.ModelForm):
user = forms.ModelChoiceField(queryset=CustomUser.objects.all(),
widget=forms.HiddenInput())
class Meta:
model = UserBillingAddress
fields = ['cardholder_name', 'street_address',
'city', 'postal_code', 'country', 'user', 'vat_number']
labels = {
'cardholder_name': _('Cardholder Name'),
'street_address': _('Street Building'),
'city': _('City'),
'postal_code': _('Postal Code'),
'Country': _('Country'),
'vat_number': _('VAT Number'),
}
class ContactUsForm(forms.ModelForm):
error_css_class = 'autofocus'
class Meta:
model = ContactMessage
fields = ['name', 'email', 'phone_number', 'message']
widgets = {
'name': forms.TextInput(attrs={'class': u'form-control'}),
'email': forms.TextInput(attrs={'class': u'form-control'}),
'phone_number': forms.TextInput(attrs={'class': u'form-control'}),
'message': forms.Textarea(attrs={'class': u'form-control'}),
}
labels = {
'name': _('Name'),
'email': _('Email'),
'phone_number': _('Phone number'),
'message': _('Message'),
}
def send_email(self, email_to='info@digitalglarus.ch'):
text_content = render_to_string(
'emails/contact.txt', {'data': self.cleaned_data})
html_content = render_to_string(
'emails/contact.html', {'data': self.cleaned_data})
email = EmailMultiAlternatives('Subject', text_content)
email.attach_alternative(html_content, "text/html")
email.to = [email_to]
email.send()

241
utils/hosting_utils.py Executable file
View file

@ -0,0 +1,241 @@
import decimal
import logging
import subprocess
from django.conf import settings
from oca.pool import WrongIdError
from datacenterlight.models import VMPricing
from hosting.models import UserHostingKey, VMDetail, VATRates
from opennebula_api.serializers import VirtualMachineSerializer
logger = logging.getLogger(__name__)
def get_all_public_keys(customer):
"""
Returns all the public keys of the user
:param customer: The customer whose public keys are needed
:return: A list of public keys
"""
return UserHostingKey.objects.filter(user_id=customer.id).values_list(
"public_key", flat=True).distinct()
def get_or_create_vm_detail(user, manager, vm_id):
"""
Returns VMDetail object related to given vm_id. Creates the object
if it does not exist
:param vm_id: The ID of the VM which should be greater than 0.
:param user: The CustomUser object that owns this VM
:param manager: The OpenNebulaManager object
:return: The VMDetail object. None if vm_id is less than or equal to 0.
Also, for the cases where the VMDetail does not exist and we can not
fetch data about the VM from OpenNebula, the function returns None
"""
if vm_id <= 0:
return None
try:
vm_detail_obj = VMDetail.objects.get(vm_id=vm_id)
except VMDetail.DoesNotExist:
try:
vm_obj = manager.get_vm(vm_id)
except (WrongIdError, ConnectionRefusedError) as e:
logger.error(str(e))
return None
vm = VirtualMachineSerializer(vm_obj).data
vm_detail_obj = VMDetail.objects.create(
user=user, vm_id=vm_id, disk_size=vm['disk_size'],
cores=vm['cores'], memory=vm['memory'],
configuration=vm['configuration'], ipv4=vm['ipv4'],
ipv6=vm['ipv6']
)
return vm_detail_obj
def get_vm_price(cpu, memory, disk_size, hdd_size=0, pricing_name='default'):
"""
A helper function that computes price of a VM from given cpu, ram and
ssd parameters
:param cpu: Number of cores of the VM
:param memory: RAM of the VM
:param disk_size: Disk space of the VM (SSD)
:param hdd_size: The HDD size
:param pricing_name: The pricing name to be used
:return: The price of the VM
"""
try:
pricing = VMPricing.objects.get(name=pricing_name)
except Exception as ex:
logger.error(
"Error getting VMPricing object for {pricing_name}."
"Details: {details}".format(
pricing_name=pricing_name, details=str(ex)
)
)
return None
price = ((decimal.Decimal(cpu) * pricing.cores_unit_price) +
(decimal.Decimal(memory) * pricing.ram_unit_price) +
(decimal.Decimal(disk_size) * pricing.ssd_unit_price) +
(decimal.Decimal(hdd_size) * pricing.hdd_unit_price) +
decimal.Decimal(settings.VM_BASE_PRICE))
cents = decimal.Decimal('.01')
price = price.quantize(cents, decimal.ROUND_HALF_UP)
return round(float(price), 2)
def get_vm_price_for_given_vat(cpu, memory, ssd_size, hdd_size=0,
pricing_name='default', vat_rate=0):
try:
pricing = VMPricing.objects.get(name=pricing_name)
except Exception as ex:
logger.error(
"Error getting VMPricing object for {pricing_name}."
"Details: {details}".format(
pricing_name=pricing_name, details=str(ex)
)
)
return None
price = (
(decimal.Decimal(cpu) * pricing.cores_unit_price) +
(decimal.Decimal(memory) * pricing.ram_unit_price) +
(decimal.Decimal(ssd_size) * pricing.ssd_unit_price) +
(decimal.Decimal(hdd_size) * pricing.hdd_unit_price) +
decimal.Decimal(settings.VM_BASE_PRICE)
)
discount_name = pricing.discount_name
discount_amount = round(float(pricing.discount_amount), 2)
vat = price * decimal.Decimal(vat_rate) * decimal.Decimal(0.01)
vat_percent = vat_rate
cents = decimal.Decimal('.01')
price = price.quantize(cents, decimal.ROUND_HALF_UP)
vat = vat.quantize(cents, decimal.ROUND_HALF_UP)
discount_amount_with_vat = decimal.Decimal(discount_amount) * (1 + decimal.Decimal(vat_rate) * decimal.Decimal(0.01))
discount_amount_with_vat = discount_amount_with_vat.quantize(cents, decimal.ROUND_HALF_UP)
discount = {
'name': discount_name,
'amount': discount_amount,
'amount_with_vat': round(float(discount_amount_with_vat), 2),
'stripe_coupon_id': pricing.stripe_coupon_id
}
return (round(float(price), 2), round(float(vat), 2),
round(float(vat_percent), 2), discount)
def get_vm_price_with_vat(cpu, memory, ssd_size, hdd_size=0,
pricing_name='default'):
"""
A helper function that computes price of a VM from given cpu, ram and
ssd, hdd and the pricing parameters
:param cpu: Number of cores of the VM
:param memory: RAM of the VM
:param ssd_size: Disk space of the VM (SSD)
:param hdd_size: The HDD size
:param pricing_name: The pricing name to be used
:return: The a tuple containing the price of the VM, the VAT and the
VAT percentage
"""
try:
pricing = VMPricing.objects.get(name=pricing_name)
except Exception as ex:
logger.error(
"Error getting VMPricing object for {pricing_name}."
"Details: {details}".format(
pricing_name=pricing_name, details=str(ex)
)
)
return None
price = (
(decimal.Decimal(cpu) * pricing.cores_unit_price) +
(decimal.Decimal(memory) * pricing.ram_unit_price) +
(decimal.Decimal(ssd_size) * pricing.ssd_unit_price) +
(decimal.Decimal(hdd_size) * pricing.hdd_unit_price) +
decimal.Decimal(settings.VM_BASE_PRICE)
)
if pricing.vat_inclusive:
vat = decimal.Decimal(0)
vat_percent = decimal.Decimal(0)
else:
vat = price * pricing.vat_percentage * decimal.Decimal(0.01)
vat_percent = pricing.vat_percentage
cents = decimal.Decimal('.01')
price = price.quantize(cents, decimal.ROUND_HALF_UP)
vat = vat.quantize(cents, decimal.ROUND_HALF_UP)
discount = {
'name': pricing.discount_name,
'amount': round(float(pricing.discount_amount), 2),
'stripe_coupon_id': pricing.stripe_coupon_id
}
return (round(float(price), 2), round(float(vat), 2),
round(float(vat_percent), 2), discount)
def ping_ok(host_ipv6):
"""
A utility method to check if a host responds to ping requests. Note: the
function relies on `ping6` utility of debian to check.
:param host_ipv6 str type parameter that represets the ipv6 of the host to
checked
:return True if the host responds to ping else returns False
"""
try:
subprocess.check_output("ping6 -c 1 " + host_ipv6, shell=True)
except Exception as ex:
logger.debug(host_ipv6 + " not reachable via ping. Error = " + str(ex))
return False
return True
def get_vat_rate_for_country(country):
vat_rate = None
try:
vat_rate = VATRates.objects.get(
territory_codes=country, start_date__isnull=False, stop_date=None
)
logger.debug("VAT rate for %s is %s" % (country, vat_rate.rate))
return vat_rate.rate
except VATRates.DoesNotExist as dne:
logger.debug(str(dne))
logger.debug("Did not find VAT rate for %s, returning 0" % country)
return 0
def get_ip_addresses(vm_id):
try:
vm_detail = VMDetail.objects.get(vm_id=vm_id)
return "%s <br/>%s" % (vm_detail.ipv6, vm_detail.ipv4)
except VMDetail.DoesNotExist as dne:
logger.error(str(dne))
logger.error("VMDetail for %s does not exist" % vm_id)
return "--"
class HostingUtils:
@staticmethod
def clear_items_from_list(from_list, items_list):
"""
A utility function to clear items from a given list.
Useful when deleting items in bulk from session.
e.g.:
HostingUtils.clear_items_from_list(
request.session,
['token', 'billing_address_data', 'card_id',]
)
:param from_list:
:param items_list:
:return:
"""
for var in items_list:
if var in from_list:
del from_list[var]

281
utils/ldap_manager.py Executable file
View file

@ -0,0 +1,281 @@
import base64
import hashlib
import random
import ldap3
import logging
import unicodedata
from django.conf import settings
logger = logging.getLogger(__name__)
class LdapManager:
__instance = None
def __new__(cls):
if LdapManager.__instance is None:
LdapManager.__instance = object.__new__(cls)
return LdapManager.__instance
def __init__(self):
"""
Initialize the LDAP subsystem.
"""
self.rng = random.SystemRandom()
self.server = ldap3.Server(settings.AUTH_LDAP_SERVER)
def get_admin_conn(self):
"""
Return a bound :class:`ldap3.Connection` instance which has write
permissions on the dn in which the user accounts reside.
"""
conn = self.get_conn(user=settings.LDAP_ADMIN_DN,
password=settings.LDAP_ADMIN_PASSWORD,
raise_exceptions=True)
conn.bind()
return conn
def get_conn(self, **kwargs):
"""
Return an unbound :class:`ldap3.Connection` which talks to the configured
LDAP server.
The *kwargs* are passed to the constructor of :class:`ldap3.Connection` and
can be used to set *user*, *password* and other useful arguments.
"""
return ldap3.Connection(self.server, **kwargs)
def _ssha_password(self, password):
"""
Apply the SSHA password hashing scheme to the given *password*.
*password* must be a :class:`bytes` object, containing the utf-8
encoded password.
Return a :class:`bytes` object containing ``ascii``-compatible data
which can be used as LDAP value, e.g. after armoring it once more using
base64 or decoding it to unicode from ``ascii``.
"""
SALT_BYTES = 15
sha1 = hashlib.sha1()
salt = self.rng.getrandbits(SALT_BYTES * 8).to_bytes(SALT_BYTES, "little")
sha1.update(password)
sha1.update(salt)
digest = sha1.digest()
passwd = b"{SSHA}" + base64.b64encode(digest + salt)
return passwd
def create_user(self, user, password, firstname, lastname, email):
conn = self.get_admin_conn()
uidNumber = self._get_max_uid() + 1
logger.debug("uidNumber={uidNumber}".format(uidNumber=uidNumber))
user_exists = True
while user_exists:
user_exists, _ = self.check_user_exists(
"",
'(&(objectClass=inetOrgPerson)(objectClass=posixAccount)'
'(objectClass=top)(uidNumber={uidNumber}))'.format(
uidNumber=uidNumber
)
)
if user_exists:
logger.debug(
"{uid} exists. Trying next.".format(uid=uidNumber)
)
uidNumber += 1
logger.debug("{uid} does not exist. Using it".format(uid=uidNumber))
self._set_max_uid(uidNumber)
try:
uid = user
conn.add("uid={uid},{customer_dn}".format(
uid=uid, customer_dn=settings.LDAP_CUSTOMER_DN
),
["inetOrgPerson", "posixAccount", "ldapPublickey"],
{
"uid": [uid],
"sn": [lastname.encode("utf-8")],
"givenName": [firstname.encode("utf-8")],
"cn": [uid],
"displayName": ["{} {}".format(firstname, lastname).encode("utf-8")],
"uidNumber": [str(uidNumber)],
"gidNumber": [str(settings.LDAP_CUSTOMER_GROUP_ID)],
"loginShell": ["/bin/bash"],
"homeDirectory": ["/home/{}".format(unicodedata.normalize('NFKD', user).encode('ascii','ignore'))],
"mail": email.encode("utf-8"),
"userPassword": [self._ssha_password(
password.encode("utf-8")
)]
}
)
logger.debug('Created user %s %s' % (user.encode('utf-8'), uidNumber))
except Exception as ex:
logger.debug('Could not create user %s' % user.encode('utf-8'))
logger.error("Exception: " + str(ex))
raise Exception(ex)
finally:
conn.unbind()
def change_password(self, uid, new_password):
"""
Changes the password of the user identified by user_dn
:param uid: str The uid that identifies the user
:param new_password: str The new password string
:return: True if password was changed successfully False otherwise
"""
conn = self.get_admin_conn()
# Make sure the user exists first to change his/her details
user_exists, entries = self.check_user_exists(
uid=uid,
search_base=settings.ENTIRE_SEARCH_BASE
)
return_val = False
if user_exists:
try:
return_val = conn.modify(
entries[0].entry_dn,
{
"userpassword": (
ldap3.MODIFY_REPLACE,
[self._ssha_password(new_password.encode("utf-8"))]
)
}
)
except Exception as ex:
logger.error("Exception: " + str(ex))
else:
logger.error("User {} not found".format(uid))
conn.unbind()
return return_val
def change_user_details(self, uid, details):
"""
Updates the user details as per given values in kwargs of the user
identified by user_dn.
Assumes that all attributes passed in kwargs are valid.
:param uid: str The uid that identifies the user
:param details: dict A dictionary containing the new values
:return: True if user details were updated successfully False otherwise
"""
conn = self.get_admin_conn()
# Make sure the user exists first to change his/her details
user_exists, entries = self.check_user_exists(
uid=uid,
search_base=settings.ENTIRE_SEARCH_BASE
)
return_val = False
if user_exists:
details_dict = {k: (ldap3.MODIFY_REPLACE, [v.encode("utf-8")]) for
k, v in details.items()}
try:
return_val = conn.modify(entries[0].entry_dn, details_dict)
msg = "success"
except Exception as ex:
msg = str(ex)
logger.error("Exception: " + msg)
finally:
conn.unbind()
else:
msg = "User {} not found".format(uid)
logger.error(msg)
conn.unbind()
return return_val, msg
def check_user_exists(self, uid, search_filter="", attributes=None,
search_base=settings.LDAP_CUSTOMER_DN, search_attr="uid"):
"""
Check if the user with the given uid exists in the customer group.
:param uid: str representing the user
:param search_filter: str representing the filter condition to find
users. If its empty, the search finds the user with
the given uid.
:param attributes: list A list of str representing all the attributes
to be obtained in the result entries
:param search_base: str
:return: tuple (bool, [ldap3.abstract.entry.Entry ..])
A bool indicating if the user exists
A list of all entries obtained in the search
"""
conn = self.get_admin_conn()
entries = []
try:
result = conn.search(
search_base=search_base,
search_filter=search_filter if len(search_filter) > 0 else
'(uid={uid})'.format(uid=uid),
attributes=attributes
)
entries = conn.entries
finally:
conn.unbind()
return result, entries
def delete_user(self, uid):
"""
Deletes the user with the given uid from ldap
:param uid: str representing the user
:return: True if the delete was successful False otherwise
"""
conn = self.get_admin_conn()
try:
return_val = conn.delete(
("uid={uid}," + settings.LDAP_CUSTOMER_DN).format(uid=uid),
)
msg = "success"
except Exception as ex:
msg = str(ex)
logger.error("Exception: " + msg)
return_val = False
finally:
conn.unbind()
return return_val, msg
def _set_max_uid(self, max_uid):
"""
a utility function to save max_uid value to a file
:param max_uid: an integer representing the max uid
:return:
"""
with open(settings.LDAP_MAX_UID_FILE_PATH, 'w+') as handler:
handler.write(str(max_uid))
def _get_max_uid(self):
"""
A utility function to read the max uid value that was previously set
:return: An integer representing the max uid value that was previously
set
"""
try:
with open(settings.LDAP_MAX_UID_FILE_PATH, 'r+') as handler:
try:
return_value = int(handler.read())
except ValueError as ve:
logger.error(
"Error reading int value from {}. {}"
"Returning default value {} instead".format(
settings.LDAP_MAX_UID_FILE_PATH,
str(ve),
settings.LDAP_DEFAULT_START_UID
)
)
return_value = settings.LDAP_DEFAULT_START_UID
return return_value
except FileNotFoundError as fnfe:
logger.error("File not found : " + str(fnfe))
return_value = settings.LDAP_DEFAULT_START_UID
logger.error("So, returning UID={}".format(return_value))
return return_value

Binary file not shown.

File diff suppressed because it is too large Load diff

67
utils/mailer.py Executable file
View file

@ -0,0 +1,67 @@
import six
from django.core.mail import send_mail
from django.core.mail import EmailMultiAlternatives
from django.template.loader import render_to_string
class BaseEmail(object):
def __init__(self, *args, **kwargs):
self.to = kwargs.get('to')
self.template_name = kwargs.get('template_name')
self.template_path = kwargs.get('template_path')
self.subject = kwargs.get('subject')
self.context = kwargs.get('context', {})
self.template_full_path = '%s%s' % (self.template_path, self.template_name)
text_content = render_to_string('%s.txt' % self.template_full_path, self.context)
html_content = render_to_string('%s.html' % self.template_full_path, self.context)
self.email = EmailMultiAlternatives(self.subject, text_content)
self.email.attach_alternative(html_content, "text/html")
if 'from_address' in kwargs:
self.email.from_email = kwargs.get('from_address')
else:
self.email.from_email = '(ungleich) ungleich Support <info@ungleich.ch>'
self.email.to = [kwargs.get('to', 'info@ungleich.ch')]
def send(self):
self.email.send()
class BaseMailer(object):
def __init__(self):
self._slug = None
self.no_replay_mail = 'info@ungleich.ch'
if not hasattr(self, '_to'):
self._to = None
@property
def slug(self):
return self._slug
@slug.setter
def slug(self, val):
assert isinstance(val, six.string_types), "slug is not string: %r" % val
self._slug = val
@property
def registration(self):
return self.message
@registration.setter
def registration(self, val):
msg = "registration is not dict with fields subject,message"
assert type(val) is dict, msg
assert val.get('subject') and val.get('message'), msg
self._message, self._subject, self._from = (
val.get('message'), val.get('subject'), val.get('from'))
assert isinstance(self.slug, six.string_types), 'slug not set'
def send_mail(self, to=None):
if not to:
to = self._to
if not self.message:
raise NotImplementedError
send_mail(self._subject, self._message, self.no_replay_mail, [to])

View file

@ -0,0 +1,497 @@
"""
This command finds and creates a report for all the usage of css rules in
an app. It aims to optimize existing codebase as well as assist the frontend
developer when designing new components by avoiding unnecessary duplication and
suggesting more/optimal alternatives.
Features:
Currently the command can find out and display:
- Media Breakpoints used in a stylesheet
- Duplicate selectors in a stylesheet
- Unused selectors
Work in progress to enable these features:
- Duplicate style declaration for same selector
- DOM validation
- Finding out dead styles (those that are always cancelled)
- Optimize media declarations
Example:
$ python manage.py optimize_frontend datacenterlight
above command produces a file ../optimize_frontend.html which contains a
report with the above mentioned features
"""
# import csv
import json
import logging
import os
import re
from collections import Counter, OrderedDict
# from itertools import zip_longest
from django import template
from django.conf import settings
from django.contrib.staticfiles import finders
from django.core.management.base import BaseCommand
logger = logging.getLogger(__name__)
RE_PATTERNS = {
'view_html': '[\'\"](.*\.html)',
'html_html': '{% (?:extends|include) [\'\"]?(.*\.html)',
'html_style': '{% static [\'\"]?(.*\.css)',
'css_media': (
'^\s*\@media([^{]+)\{\s*([\s\S]*?})\s*}'
),
'css_selector': (
'^\s*([.#\[:_A-Za-z][^{]*?)\s*'
'\s*{\s*([\s\S]*?)\s*}'
),
'html_class': 'class=[\'\"]([a-zA-Z0-9-_\s]*)',
'html_id': 'id=[\'\"]([a-zA-Z0-9-_]*)'
}
class Command(BaseCommand):
help = (
'Finds unused and duplicate style declarations from the stylesheets '
'used in the templates of each app'
)
requires_system_checks = False
def add_arguments(self, parser):
# positional arguments
parser.add_argument(
'apps', nargs='+', type=str,
help='name of the apps to be optimized'
)
# Named (optional) arguments
parser.add_argument(
'--together',
action='store_true',
help='optimize the apps together'
)
parser.add_argument(
'--css',
action='store_true',
help='optimize only the css rules declared in each stylesheet'
)
def handle(self, *args, **options):
apps_list = options['apps']
report = {}
for app in apps_list:
if options['css']:
report[app] = self.optimize_css(app)
# write report
write_report(report)
def optimize_css(self, app_name):
"""Optimize declarations inside a css stylesheet
Args:
app_name (str): The application name
"""
# get html and css files used in the app
files = get_files(app_name)
# get_selectors_from_css
css_selectors = get_selectors_css(files['style'])
# get_selectors_from_html
html_selectors = get_selectors_html(files['html'])
report = {
'css_dup': get_css_duplication(css_selectors),
'css_unused': get_css_unused(css_selectors, html_selectors)
}
return report
def get_files(app_name):
"""Get all the `html` and `css` files used in an app.
Args:
app_name (str): The application name
Returns:
dict: A dictonary containing Counter of occurence of each
html and css file in `html` and `style` fields respectively.
For example:
{
'html': {'datacenterlight/success.html': 1},
'style': {'datacenterlight/css/bootstrap.min.css': 2}
}
"""
# the view file for the app
app_view = os.path.join(settings.PROJECT_DIR, app_name, 'views.py')
# get template files called from the view
all_html_list = file_match_pattern(app_view, 'view_html')
# list of unique template files
uniq_html_list = list(OrderedDict.fromkeys(all_html_list).keys())
# list of stylesheets
all_style_list = []
file_patterns = ['html_html', 'html_style']
# get html and css files called from within templates
i = 0
while i < len(uniq_html_list):
template_name = uniq_html_list[i]
try:
temp_files = templates_match_pattern(
template_name, file_patterns
)
except template.exceptions.TemplateDoesNotExist as e:
print("template file not found: ", str(e))
all_html_list = [
h for h in all_html_list if h != template_name
]
del uniq_html_list[i]
else:
all_html_list.extend(temp_files[0])
uniq_html_list = list(
OrderedDict.fromkeys(all_html_list).keys()
)
all_style_list.extend(temp_files[1])
i += 1
# counter dict for the html files called from view
result = {
'html': Counter(all_html_list),
'style': Counter(all_style_list)
}
# print(result)
return result
def get_selectors_css(files):
"""Gets the selectors and declarations from a stylesheet.
Args:
files (list): A list of path of stylesheets.
Returns:
dict: A nested dictionary with the structre as
`{'file': {'media-selector': [('selectors',`declarations')]}}`
For example:
{
'datacenterlight/css/landing-page.css':{
'(min-width: 768px)': [
('.lead-right', 'text-align: right;'),
]
}
}
"""
selectors = {}
media_selectors = {}
# get media selectors and other simple declarations
for file in files:
if any(vendor in file for vendor in ['bootstrap', 'font-awesome']):
continue
result = finders.find(file)
if result:
with open(result) as f:
data = f.read()
media_selectors[file] = string_match_pattern(data, 'css_media')
new_data = string_remove_pattern(data, 'css_media')
default_match = string_match_pattern(new_data, 'css_selector')
selectors[file] = {
'default': [
[' '.join(grp.split()) for grp in m] for m in default_match
]
}
# get declarations from media queries
for file, match_list in media_selectors.items():
for match in match_list:
query = match[0]
block_text = ' '.join(match[1].split())
results = string_match_pattern(
block_text, 'css_selector'
)
f_query = ' '.join(query.replace(':', ': ').split())
if f_query in selectors[file]:
selectors[file][f_query].extend(results)
else:
selectors[file][f_query] = results
return selectors
def get_selectors_html(files):
"""Get `class` and `id` used in html files.
Args:
files (list): A list of html files path.
Returns:
dict: a dictonary of all the classes and ids found in the file, in
`class` and `id` field respectively.
"""
selectors = {}
for file in files:
results = templates_match_pattern(file, ['html_class', 'html_id'])
class_dict = {c: 1 for match in results[0] for c in match.split()}
selectors[file] = {
'classes': list(class_dict.keys()),
'ids': results[1],
}
return selectors
def file_match_pattern(file, patterns):
"""Match a regex pattern in a file
Args:
file (str): Complete path of file
patterns (list or str): The pattern(s) to be searched in the file
Returns:
list: A list of all the matches in the file. Each item is a list of
all the captured groups in the pattern. If multiple patterns are given,
the returned list is a list of such lists.
For example:
[('.lead', 'font-size: 18px;'), ('.btn-lg', 'min-width: 180px;')]
"""
with open(file) as f:
data = f.read()
results = string_match_pattern(data, patterns)
return results
def string_match_pattern(data, patterns):
"""Match a regex pattern in a string
Args:
data (str): the string to search for the pattern
patterns (list or str): The pattern(s) to be searched in the file
Returns:
list: A list of all the matches in the string. Each item is a list of
all the captured groups in the pattern. If multiple patterns are given,
the returned list is a list of such lists.
For example:
[('.lead', 'font-size: 18px;'), ('.btn-lg', 'min-width: 180px;')]
"""
if not isinstance(patterns, str):
results = []
for p in patterns:
re_pattern = re.compile(RE_PATTERNS[p], re.MULTILINE)
results.append(re.findall(re_pattern, data))
else:
re_pattern = re.compile(RE_PATTERNS[patterns], re.MULTILINE)
results = re.findall(re_pattern, data)
return results
def string_remove_pattern(data, patterns):
"""Remove a pattern from a string
Args:
data (str): the string to search for the patter
patterns (list or str): The pattern(s) to be removed from the file
Returns:
str: The new string with all instance of matching pattern
removed from it
"""
if not isinstance(patterns, str):
for p in patterns:
re_pattern = re.compile(RE_PATTERNS[p], re.MULTILINE)
data = re.sub(re_pattern, '', data)
else:
re_pattern = re.compile(RE_PATTERNS[patterns], re.MULTILINE)
data = re.sub(re_pattern, '', data)
return data
def templates_match_pattern(template_name, patterns):
"""Match a regex pattern in the first found template file
Args:
file (str): Path of template file
patterns (list or str): The pattern(s) to be searched in the file
Returns:
list: A list of all the matches in the file. Each item is a list of
all the captured groups in the pattern. If multiple patterns are given,
the returned list is a list of such lists.
For example:
[('.lead', 'font-size: 18px;'), ('.btn-lg', 'min-width: 180px;')]
"""
t = template.loader.get_template(template_name)
data = t.template.source
results = string_match_pattern(data, patterns)
return results
def get_css_duplication(css_selectors):
"""Get duplicate selectors from the same stylesheet
Args:
css_selectors (dict): A dictonary containing css selectors from
all the files in the app in the below structure.
`{'file': {'media-selector': [('selectors',`declarations')]}}`
Returns:
dict: A dictonary containing the count of any duplicate selector in
each file.
`{'file': {'media-selector': {'selector': count}}}`
"""
# duplicate css selectors in stylesheets
rule_count = {}
for file, media_selectors in css_selectors.items():
rule_count[file] = {}
for media, rules in media_selectors.items():
rules_dict = Counter([rule[0] for rule in rules])
dup_rules_dict = {k: v for k, v in rules_dict.items() if v > 1}
if dup_rules_dict:
rule_count[file][media] = dup_rules_dict
return rule_count
def get_css_unused(css_selectors, html_selectors):
"""Get selectors from stylesheets that are not used in any of the html
files in which the stylesheet is used.
Args:
css_selectors (dict): A dictonary containing css selectors from
all the files in the app in the below structure.
`{'file': {'media-selector': [('selectors',`declarations')]}}`
html_selectors (dict): A dictonary containing the 'class' and 'id'
declarations from all html files
"""
with open('utils/optimize/test.json', 'w') as f:
json.dump([html_selectors, css_selectors], f, indent=4)
# print(html_selectors, css_selectors)
def write_report(all_reports, filename='frontend'):
"""Write the generated report to a file for re-use
Args;
all_reports (dict): A dictonary of report obtained from different tests
filename (str): An optional suffix for the output file
"""
# full_filename = 'utils/optimize/optimize_' + filename + '.html'
# output_file = os.path.join(
# settings.PROJECT_DIR, full_filename
# )
with open('utils/optimize/op_frontend.json', 'w') as f:
json.dump(all_reports, f, indent=4)
# with open(output_file, 'w', newline='') as f:
# f.write(
# template.loader.render_to_string(
# 'utils/report.html', {'all_reports': all_reports}
# )
# )
# w = csv.writer(f)
# print(zip_longest(*results))
# for r in zip_longest(*results):
# w.writerow(r)
# a list of all the html tags (to be moved in a json file)
html_tags = [
"a",
"abbr",
"address",
"article",
"area",
"aside",
"audio",
"b",
"base",
"bdi",
"bdo",
"blockquote",
"body",
"br",
"button",
"canvas",
"caption",
"cite",
"code",
"col",
"colgroup",
"datalist",
"dd",
"del",
"details",
"dfn",
"div",
"dl",
"dt",
"em",
"embed",
"fieldset",
"figcaption",
"figure",
"footer",
"form",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"head",
"header",
"hgroup",
"hr",
"html",
"i",
"iframe",
"img",
"input",
"ins",
"kbd",
"keygen",
"label",
"legend",
"li",
"link",
"map",
"mark",
"menu",
"meta",
"meter",
"nav",
"noscript",
"object",
"ol",
"optgroup",
"option",
"output",
"p",
"param",
"pre",
"progress",
"q",
"rp",
"rt",
"ruby",
"s",
"samp",
"script",
"section",
"select",
"source",
"small",
"span",
"strong",
"style",
"sub",
"summary",
"sup",
"textarea",
"table",
"tbody",
"td",
"tfoot",
"thead",
"th",
"time",
"title",
"tr",
"u",
"ul",
"var",
"video",
"wbr"
]

17
utils/middleware.py Executable file
View file

@ -0,0 +1,17 @@
#class MultipleProxyMiddleware(object):
# FORWARDED_FOR_FIELDS = [
# 'HTTP_X_FORWARDED_FOR',
# 'HTTP_X_FORWARDED_HOST',
# 'HTTP_X_FORWARDED_SERVER',
# ]
#
# def process_request(self, request):
# """
# Rewrites the proxy headers so that only the most
# recent proxy is used.
# """
# for field in self.FORWARDED_FOR_FIELDS:
# if field in request.META:
# if ',' in request.META[field]:
# parts = request.META[field].split(',')
# request.META[field] = parts[-1].strip()

File diff suppressed because one or more lines are too long

View file

31
utils/mixins.py Executable file
View file

@ -0,0 +1,31 @@
from guardian.shortcuts import assign_perm
class AssignPermissionsMixin(object):
permissions = tuple()
user = None
obj = None
kwargs = dict()
def assign_permissions(self, user):
for permission in self.permissions:
assign_perm(permission, user, self)
# def save(self, *args, **kwargs):
# self.kwargs = kwargs
# self.get_objs()
# create = False
# if not self.pk:
# create = True
# super(AssignPermissionsMixin, self).save(*args, **kwargs)
# if create:
# self.assign_permissions()
# def get_objs(self):
# self.user = self.kwargs.pop('user', None)
# self.obj = self.kwargs.pop('obj', None)
# assert self.user, 'Se necesita el parámetro user para poder asignar los permisos'
# assert self.obj, 'Se necesita el parámetro obj para poder asignar los permisos'

79
utils/models.py Executable file
View file

@ -0,0 +1,79 @@
from django.db import models
from membership.models import CustomUser
from .fields import CountryField
# Create your models here.
class BaseBillingAddress(models.Model):
cardholder_name = models.CharField(max_length=100, default="")
street_address = models.CharField(max_length=100)
city = models.CharField(max_length=50)
postal_code = models.CharField(max_length=50)
country = CountryField()
vat_number = models.CharField(max_length=100, default="", blank=True)
stripe_tax_id = models.CharField(max_length=100, default="", blank=True)
vat_number_validated_on = models.DateTimeField(blank=True, null=True)
vat_validation_status = models.CharField(max_length=25, default="",
blank=True)
class Meta:
abstract = True
class BillingAddress(BaseBillingAddress):
def __str__(self):
if self.vat_number:
return "%s, %s, %s, %s, %s, %s %s %s %s" % (
self.cardholder_name, self.street_address, self.city,
self.postal_code, self.country, self.vat_number,
self.stripe_tax_id, self.vat_number_validated_on,
self.vat_validation_status
)
else:
return "%s, %s, %s, %s, %s" % (
self.cardholder_name, self.street_address, self.city,
self.postal_code, self.country
)
class UserBillingAddress(BaseBillingAddress):
user = models.ForeignKey(CustomUser, related_name='billing_addresses', on_delete=models.CASCADE)
current = models.BooleanField(default=True)
def __str__(self):
if self.vat_number:
return "%s, %s, %s, %s, %s, %s %s %s %s" % (
self.cardholder_name, self.street_address, self.city,
self.postal_code, self.country, self.vat_number,
self.stripe_tax_id, self.vat_number_validated_on,
self.vat_validation_status
)
else:
return "%s, %s, %s, %s, %s" % (
self.cardholder_name, self.street_address, self.city,
self.postal_code, self.country
)
def to_dict(self):
return {
'Cardholder Name': self.cardholder_name,
'Street Address': self.street_address,
'City': self.city,
'Postal Code': self.postal_code,
'Country': self.country,
'VAT Number': self.vat_number
}
class ContactMessage(models.Model):
name = models.CharField(max_length=200)
email = models.EmailField()
phone_number = models.CharField(max_length=200, blank=True)
message = models.TextField()
received_date = models.DateTimeField(auto_now_add=True)
def __str__(self):
return "%s - %s - %s" % (self.name, self.email, self.received_date)

0
utils/optimize/.gitkeep Executable file
View file

573
utils/stripe_utils.py Normal file
View file

@ -0,0 +1,573 @@
import logging
import re
import stripe
from django.conf import settings
from datacenterlight.models import StripePlan
stripe.api_key = settings.STRIPE_API_PRIVATE_KEY
logger = logging.getLogger(__name__)
def handleStripeError(f):
def handleProblems(*args, **kwargs):
response = {
'paid': False,
'response_object': None,
'error': None
}
common_message = "Currently it's not possible to make payments."
try:
response_object = f(*args, **kwargs)
response = {
'response_object': response_object,
'error': None
}
return response
except stripe.error.CardError as e:
# Since it's a decline, stripe.error.CardError will be caught
body = e.json_body
err = body['error']
response.update({'error': err['message']})
logger.error(str(e))
return response
except stripe.error.RateLimitError as e:
logger.error(str(e))
response.update(
{'error': "Too many requests made to the API too quickly"})
return response
except stripe.error.InvalidRequestError as e:
logger.error(str(e))
response.update({'error': str(e._message)})
return response
except stripe.error.AuthenticationError as e:
# Authentication with Stripe's API failed
# (maybe you changed API keys recently)
logger.error(str(e))
response.update({'error': str(e)})
return response
except stripe.error.APIConnectionError as e:
logger.error(str(e))
response.update({'error': str(e)})
return response
except stripe.error.StripeError as e:
# maybe send email
logger.error(str(e))
response.update({'error': str(e)})
return response
except Exception as e:
# maybe send email
logger.error(str(e))
response.update({'error': str(e)})
return response
return handleProblems
class StripeUtils(object):
CURRENCY = 'chf'
INTERVAL = 'month'
SUCCEEDED_STATUS = 'succeeded'
RESOURCE_ALREADY_EXISTS_ERROR_CODE = 'resource_already_exists'
STRIPE_NO_SUCH_PLAN = 'No such plan'
PLAN_EXISTS_ERROR_MSG = 'Plan {} exists already.\nCreating a local StripePlan now.'
PLAN_DOES_NOT_EXIST_ERROR_MSG = 'Plan {} does not exist.'
def __init__(self):
self.stripe = stripe
def update_customer_token(self, customer, token):
customer.source = token
customer.save()
@handleStripeError
def associate_customer_card(self, stripe_customer_id, id_payment_method,
set_as_default=False):
customer = stripe.Customer.retrieve(stripe_customer_id)
stripe.PaymentMethod.attach(
id_payment_method,
customer=stripe_customer_id,
)
if set_as_default:
customer.invoice_settings.default_payment_method = id_payment_method
customer.save()
return True
@handleStripeError
def dissociate_customer_card(self, stripe_customer_id, card_id):
customer = stripe.Customer.retrieve(stripe_customer_id)
if card_id.startswith("pm"):
logger.debug("PaymentMethod %s detached %s" % (card_id,
stripe_customer_id))
pm = stripe.PaymentMethod.retrieve(card_id)
stripe.PaymentMethod.detach(card_id)
pm.delete()
else:
logger.debug("card %s detached %s" % (card_id, stripe_customer_id))
card = customer.sources.retrieve(card_id)
card.delete()
@handleStripeError
def update_customer_card(self, customer_id, token):
customer = stripe.Customer.retrieve(customer_id)
current_card_token = customer.default_source
customer.sources.retrieve(current_card_token).delete()
customer.source = token
customer.save()
credit_card_raw_data = customer.sources.data.pop()
new_card_data = {
'last4': credit_card_raw_data.last4,
'brand': credit_card_raw_data.brand
}
return new_card_data
@handleStripeError
def get_card_details(self, customer_id):
customer = stripe.Customer.retrieve(customer_id)
credit_card_raw_data = customer.sources.data.pop()
card_details = {
'last4': credit_card_raw_data.last4,
'brand': credit_card_raw_data.brand,
'exp_month': credit_card_raw_data.exp_month,
'exp_year': credit_card_raw_data.exp_year,
'fingerprint': credit_card_raw_data.fingerprint,
'card_id': credit_card_raw_data.id
}
return card_details
@handleStripeError
def get_all_invoices(self, customer_id, created_gt):
return_list = []
has_more_invoices = True
starting_after = False
while has_more_invoices:
if starting_after:
invoices = stripe.Invoice.list(
limit=10, customer=customer_id, created={'gt': created_gt},
starting_after=starting_after
)
else:
invoices = stripe.Invoice.list(
limit=10, customer=customer_id, created={'gt': created_gt}
)
has_more_invoices = invoices.has_more
for invoice in invoices.data:
sub_ids = []
for line in invoice.lines.data:
if line.type == 'subscription':
sub_ids.append(line.id)
elif line.type == 'invoiceitem':
sub_ids.append(line.subscription)
else:
sub_ids.append('')
invoice_details = {
'created': invoice.created,
'receipt_number': invoice.receipt_number,
'invoice_number': invoice.number,
'paid_at': invoice.status_transitions.paid_at if invoice.paid else 0,
'period_start': invoice.period_start,
'period_end': invoice.period_end,
'billing_reason': invoice.billing_reason,
'discount': invoice.discount.coupon.amount_off if invoice.discount else 0,
'total': invoice.total,
# to see how many line items we have in this invoice and
# then later check if we have more than 1
'lines_data_count': len(invoice.lines.data) if invoice.lines.data is not None else 0,
'invoice_id': invoice.id,
'lines_meta_data_csv': ','.join(
[line.metadata.VM_ID if hasattr(line.metadata, 'VM_ID') else '' for line in invoice.lines.data]
),
'subscription_ids_csv': ','.join(sub_ids),
'line_items': invoice.lines.data
}
starting_after = invoice.id
return_list.append(invoice_details)
return return_list
@handleStripeError
def get_cards_details_from_token(self, token):
stripe_token = stripe.Token.retrieve(token)
card_details = {
'last4': stripe_token.card.last4,
'brand': stripe_token.card.brand,
'exp_month': stripe_token.card.exp_month,
'exp_year': stripe_token.card.exp_year,
'fingerprint': stripe_token.card.fingerprint,
'card_id': stripe_token.card.id
}
return card_details
@handleStripeError
def get_cards_details_from_payment_method(self, payment_method_id):
payment_method = stripe.PaymentMethod.retrieve(payment_method_id)
# payment_method does not always seem to have a card with id
# if that is the case, fallback to payment_method_id for card_id
card_id = payment_method_id
if hasattr(payment_method.card, 'id'):
card_id = payment_method.card.id
card_details = {
'last4': payment_method.card.last4,
'brand': payment_method.card.brand,
'exp_month': payment_method.card.exp_month,
'exp_year': payment_method.card.exp_year,
'fingerprint': payment_method.card.fingerprint,
'card_id': card_id
}
return card_details
def check_customer(self, stripe_cus_api_id, user, token):
try:
customer = stripe.Customer.retrieve(stripe_cus_api_id)
except stripe.InvalidRequestError:
customer = self.create_customer(token, user.email, user.name)
user.stripecustomer.stripe_id = customer.get(
'response_object').get('id')
user.stripecustomer.save()
if type(customer) is dict:
customer = customer['response_object']
return customer
@handleStripeError
def get_customer(self, stripe_api_cus_id):
customer = stripe.Customer.retrieve(stripe_api_cus_id)
# data = customer.get('response_object')
return customer
@handleStripeError
def create_customer(self, id_payment_method, email, name=None):
if name is None or name.strip() == "":
name = email
customer = self.stripe.Customer.create(
payment_method=id_payment_method,
description=name,
email=email
)
return customer
@handleStripeError
def make_charge(self, amount=None, customer=None):
_amount = float(amount)
amount = int(_amount * 100) # stripe amount unit, in cents
charge = self.stripe.Charge.create(
amount=amount, # in cents
currency=self.CURRENCY,
customer=customer
)
return charge
@handleStripeError
def get_or_create_stripe_plan(self, amount, name, stripe_plan_id,
interval=""):
"""
This function checks if a StripePlan with the given
stripe_plan_id already exists. If it exists then the function
returns this object otherwise it creates a new StripePlan and
returns the new object.
:param amount: The amount in CHF
:param name: The name of the Stripe plan to be created.
:param stripe_plan_id: The id of the Stripe plan to be
created. Use get_stripe_plan_id_string function to
obtain the name of the plan to be created
:param interval: str representing the interval of the Plan
Specifies billing frequency. Either day, week, month or year.
Ref: https://stripe.com/docs/api/plans/create#create_plan-interval
The default is month
:return: The StripePlan object if it exists else creates a
Plan object in Stripe and a local StripePlan and
returns it. Returns None in case of Stripe error
"""
_amount = float(amount)
amount = int(_amount * 100) # stripe amount unit, in cents
stripe_plan_db_obj = None
plan_interval = interval if interval != "" else self.INTERVAL
try:
stripe_plan_db_obj = StripePlan.objects.get(
stripe_plan_id=stripe_plan_id)
except StripePlan.DoesNotExist:
try:
self.stripe.Plan.create(
amount=amount,
interval=plan_interval,
name=name,
currency=self.CURRENCY,
id=stripe_plan_id)
stripe_plan_db_obj = StripePlan.objects.create(
stripe_plan_id=stripe_plan_id)
except stripe.error.InvalidRequestError as e:
logger.error(str(e))
logger.error("error_code = %s" % str(e.__dict__))
if self.RESOURCE_ALREADY_EXISTS_ERROR_CODE in e.error.code:
logger.debug(
self.PLAN_EXISTS_ERROR_MSG.format(stripe_plan_id))
stripe_plan_db_obj, c = StripePlan.objects.get_or_create(
stripe_plan_id=stripe_plan_id)
if c:
logger.debug("Created stripe plan %s" % stripe_plan_id)
else:
logger.debug("Plan %s exists already" % stripe_plan_id)
return stripe_plan_db_obj
@handleStripeError
def delete_stripe_plan(self, stripe_plan_id):
"""
Deletes the Plan in Stripe and also deletes the local db copy
of the plan if it exists
:param stripe_plan_id: The stripe plan id that needs to be
deleted
:return: True if the plan was deleted successfully from
Stripe, False otherwise.
"""
return_value = False
try:
plan = self.stripe.Plan.retrieve(stripe_plan_id)
plan.delete()
return_value = True
StripePlan.objects.filter(
stripe_plan_id=stripe_plan_id).all().delete()
except stripe.error.InvalidRequestError as e:
if self.STRIPE_NO_SUCH_PLAN in str(e):
logger.debug(
self.PLAN_DOES_NOT_EXIST_ERROR_MSG.format(stripe_plan_id))
return return_value
@handleStripeError
def subscribe_customer_to_plan(self, customer, plans, trial_end=None,
coupon="", tax_rates=list(),
default_payment_method=""):
"""
Subscribes the given customer to the list of given plans
:param default_payment_method:
:param tax_rates:
:param coupon:
:param customer: The stripe customer identifier
:param plans: A list of stripe plans.
:param trial_end: An integer representing when the Stripe subscription
is supposed to end
Ref: https://stripe.com/docs/api/python#create_subscription-items
e.g.
plans = [
{
"plan": "dcl-v1-cpu-2-ram-5gb-ssd-10gb",
},
]
:return: The subscription StripeObject
"""
logger.debug("Subscribing %s to plan %s : coupon = %s" % (
customer, str(plans), str(coupon)
))
subscription_result = self.stripe.Subscription.create(
customer=customer, items=plans, trial_end=trial_end,
coupon=coupon,
default_tax_rates=tax_rates,
payment_behavior='allow_incomplete',
default_payment_method=default_payment_method
)
logger.debug("Done subscribing")
return subscription_result
@handleStripeError
def set_subscription_metadata(self, subscription_id, metadata):
subscription = stripe.Subscription.retrieve(subscription_id)
subscription.metadata = metadata
subscription.save()
@handleStripeError
def unsubscribe_customer(self, subscription_id):
"""
Cancels a given subscription
:param subscription_id: The Stripe subscription id string
:return:
"""
sub = stripe.Subscription.retrieve(subscription_id)
return sub.delete()
@handleStripeError
def make_payment(self, customer, amount, token):
charge = self.stripe.Charge.create(
amount=amount, # in cents
currency=self.CURRENCY,
customer=customer
)
return charge
@staticmethod
def get_stripe_plan_id(cpu, ram, ssd, version, app='dcl', hdd=None,
price=None, excl_vat=True):
"""
Returns the Stripe plan id string of the form
`dcl-v1-cpu-2-ram-5gb-ssd-10gb` based on the input parameters
:param cpu: The number of cores
:param ram: The size of the RAM in GB
:param ssd: The size of ssd storage in GB
:param hdd: The size of hdd storage in GB
:param version: The version of the Stripe plans
:param app: The application to which the stripe plan belongs
to. By default it is 'dcl'
:param price: The price for this plan
:return: A string of the form `dcl-v1-cpu-2-ram-5gb-ssd-10gb`
"""
dcl_plan_string = 'cpu-{cpu}-ram-{ram}gb-ssd-{ssd}gb'.format(cpu=cpu,
ram=ram,
ssd=ssd)
if hdd is not None:
dcl_plan_string = '{dcl_plan_string}-hdd-{hdd}gb'.format(
dcl_plan_string=dcl_plan_string, hdd=hdd)
stripe_plan_id_string = '{app}-v{version}-{plan}'.format(
app=app,
version=version,
plan=dcl_plan_string
)
if price is not None:
stripe_plan_id_string = '{}-{}chf'.format(
stripe_plan_id_string,
round(price, 2)
)
if excl_vat:
stripe_plan_id_string = '{}-{}'.format(
stripe_plan_id_string,
"excl_vat"
)
return stripe_plan_id_string
@staticmethod
def get_vm_config_from_stripe_id(stripe_id):
"""
Given a string like "dcl-v1-cpu-2-ram-5gb-ssd-10gb" return different
configuration params as a dict
:param stripe_id|str
:return: dict
"""
pattern = re.compile(r'^dcl-v(\d+)-cpu-(\d+)-ram-(\d+\.?\d*)gb-ssd-(\d+)gb-?(\d*\.?\d*)(chf)?$')
match_res = pattern.match(stripe_id)
if match_res is not None:
price = None
try:
price=match_res.group(5)
except IndexError as ie:
logger.debug("Did not find price in {}".format(stripe_id))
return {
'version': match_res.group(1),
'cores': match_res.group(2),
'ram': match_res.group(3),
'ssd': match_res.group(4),
'price': price
}
@staticmethod
def get_stripe_plan_name(cpu, memory, disk_size, price, excl_vat=True):
"""
Returns the Stripe plan name
:return:
"""
if excl_vat:
return "{cpu} Cores, {memory} GB RAM, {disk_size} GB SSD, " \
"{price} CHF Excl. VAT".format(
cpu=cpu,
memory=memory,
disk_size=disk_size,
price=round(price, 2)
)
else:
return "{cpu} Cores, {memory} GB RAM, {disk_size} GB SSD, " \
"{price} CHF".format(
cpu=cpu,
memory=memory,
disk_size=disk_size,
price=round(price, 2)
)
@handleStripeError
def set_subscription_meta_data(self, subscription_id, meta_data):
"""
Adds VM metadata to a subscription
:param subscription_id: Stripe identifier for the subscription
:param meta_data: A dict of meta data to be added
:return:
"""
subscription = stripe.Subscription.retrieve(subscription_id)
subscription.metadata = meta_data
subscription.save()
@handleStripeError
def get_or_create_tax_id_for_user(self, stripe_customer_id, vat_number,
type="eu_vat", country=""):
tax_ids_list = stripe.Customer.list_tax_ids(
stripe_customer_id,
limit=100,
)
for tax_id_obj in tax_ids_list.data:
if self.compare_vat_numbers(tax_id_obj.value, vat_number):
logger.debug("tax id obj exists already")
return tax_id_obj
else:
logger.debug(
"{val1} is not equal to {val2} or {con1} not same as "
"{con2}".format(val1=tax_id_obj.value, val2=vat_number,
con1=tax_id_obj.country.lower(),
con2=country.lower().strip()))
logger.debug(
"tax id obj does not exist for {val}. Creating a new one".format(
val=vat_number
))
tax_id_obj = stripe.Customer.create_tax_id(
stripe_customer_id,
type=type,
value=vat_number,
)
return tax_id_obj
@handleStripeError
def get_payment_intent(self, amount, customer):
""" Create a stripe PaymentIntent of the given amount and return it
:param amount: the amount of payment_intent
:return:
"""
payment_intent_obj = stripe.PaymentIntent.create(
amount=amount,
currency='chf',
customer=customer,
setup_future_usage='off_session'
)
return payment_intent_obj
@handleStripeError
def get_available_payment_methods(self, customer):
""" Retrieves all payment methods of the given customer
:param customer: StripeCustomer object
:return: a list of available payment methods
"""
return_list = []
if customer is None:
return return_list
cu = stripe.Customer.retrieve(customer.stripe_id)
pms = stripe.PaymentMethod.list(
customer=customer.stripe_id,
type="card",
)
default_source = None
if cu.default_source:
default_source = cu.default_source
else:
default_source = cu.invoice_settings.default_payment_method
for pm in pms.data:
return_list.append({
'last4': pm.card.last4, 'brand': pm.card.brand, 'id': pm.id,
'exp_year': pm.card.exp_year,
'exp_month': '{:02d}'.format(pm.card.exp_month),
'preferred': pm.id == default_source
})
return return_list
def compare_vat_numbers(self, vat1, vat2):
_vat1 = vat1.replace(" ", "").replace(".", "").replace("-","")
_vat2 = vat2.replace(" ", "").replace(".", "").replace("-","")
return True if _vat1 == _vat2 else False

98
utils/tasks.py Normal file
View file

@ -0,0 +1,98 @@
import tempfile
import cdist
from cdist.integration import configure_hosts_simple
from celery.result import AsyncResult
from celery import current_task
from celery.utils.log import get_task_logger
from django.conf import settings
from django.core.mail import EmailMessage
from dynamicweb2.pr_celery import app
logger = get_task_logger(__name__)
@app.task(bind=True, max_retries=settings.CELERY_MAX_RETRIES)
def send_plain_email_task(self, email_data):
"""
This is a generic celery task to be used for sending emails.
A celery wrapper task for EmailMessage
:param self:
:param email_data: A dict of all needed email headers
:return:
"""
email = EmailMessage(**email_data)
email.send()
@app.task(bind=True, max_retries=settings.CELERY_MAX_RETRIES)
def save_ssh_key(self, hosts, keys):
"""
Saves ssh key into the VMs of a user using cdist
:param hosts: A list of hosts to be configured
:param keys: A list of keys to be added. A key should be dict of the
form {
'value': 'sha-.....', # public key as string
'state': True # whether key is to be added or
} # removed
"""
logger.debug(
"Running save_ssh_key on {}".format(current_task.request.hostname))
logger.debug("""Running save_ssh_key task for
Hosts: {hosts_str}
Keys: {keys_str}""".format(hosts_str=", ".join(hosts),
keys_str=", ".join([
"{value}->{state}".format(
value=key.get('value'),
state=str(
key.get('state')))
for key in keys]))
)
return_value = True
with tempfile.NamedTemporaryFile(delete=True) as tmp_manifest:
# Generate manifest to be used for configuring the hosts
lines_list = [
' --key "{key}" --state {state} \\\n'.format(
key=key['value'],
state='present' if key['state'] else 'absent'
).encode('utf-8')
for key in keys]
lines_list.insert(0, b'__ssh_authorized_keys root \\\n')
tmp_manifest.writelines(lines_list)
tmp_manifest.flush()
try:
configure_hosts_simple(hosts,
tmp_manifest.name,
verbose=cdist.argparse.VERBOSE_TRACE)
except Exception as cdist_exception:
logger.error(cdist_exception)
return_value = False
email_data = {
'subject': "celery save_ssh_key error - task id {0}".format(
self.request.id.__str__()),
'from_email': current_task.request.hostname,
'to': settings.DCL_ERROR_EMAILS_TO_LIST,
'body': "Task Id: {0}\nResult: {1}\nTraceback: {2}".format(
self.request.id.__str__(), False, str(cdist_exception)),
}
send_plain_email_task(email_data)
return return_value
@app.task
def save_ssh_key_error_handler(uuid):
result = AsyncResult(uuid)
exc = result.get(propagate=False)
logger.error('Task {0} raised exception: {1!r}\n{2!r}'.format(
uuid, exc, result.traceback))
email_data = {
'subject': "[celery error] Save SSH key error {0}".format(uuid),
'from_email': current_task.request.hostname,
'to': settings.DCL_ERROR_EMAILS_TO_LIST,
'body': "Task Id: {0}\nResult: {1}\nTraceback: {2}".format(
uuid, exc, result.traceback),
}
send_plain_email_task(email_data)

96
utils/test_forms.py Executable file
View file

@ -0,0 +1,96 @@
from django.test import TestCase
from .forms import ContactUsForm, BillingAddressForm, PasswordResetRequestForm,\
SetPasswordForm
from model_mommy import mommy
class PasswordResetRequestFormTest(TestCase):
def setUp(self):
self.user = mommy.make('CustomUser')
self.completed_data = {
'email': self.user.email,
}
self.incorrect_data = {
'email': 'test',
}
def test_valid_form(self):
form = PasswordResetRequestForm(data=self.completed_data)
self.assertTrue(form.is_valid())
def test_invalid_form(self):
form = PasswordResetRequestForm(data=self.incorrect_data)
self.assertFalse(form.is_valid())
class SetPasswordFormTest(TestCase):
def setUp(self):
# self.user = mommy.make('CustomUser')
self.completed_data = {
'new_password1': 'new_password',
'new_password2': 'new_password',
}
self.incorrect_data = {
'email': 'test',
}
def test_valid_form(self):
form = SetPasswordForm(data=self.completed_data)
self.assertTrue(form.is_valid())
def test_invalid_form(self):
form = SetPasswordForm(data=self.incorrect_data)
self.assertFalse(form.is_valid())
class ContactUsFormTest(TestCase):
def setUp(self):
self.completed_data = {
'name': 'test',
'email': 'test@gmail.com',
'phone_number': '32123123123123',
'message': 'This is a message',
}
self.incompleted_data = {
'name': 'test',
}
def test_valid_form(self):
form = ContactUsForm(data=self.completed_data)
self.assertTrue(form.is_valid())
def test_invalid_form(self):
form = ContactUsForm(data=self.incompleted_data)
self.assertFalse(form.is_valid())
class BillingAddressFormTest(TestCase):
def setUp(self):
self.completed_data = {
'cardholder_name': 'test',
'street_address': 'street name',
'city': 'MyCity',
'postal_code': '32123123123123',
'country': 'VE',
'token': 'a23kfmslwxhkwis'
}
self.incompleted_data = {
'street_address': 'test',
}
def test_valid_form(self):
form = BillingAddressForm(data=self.completed_data)
self.assertTrue(form.is_valid())
def test_invalid_form(self):
form = BillingAddressForm(data=self.incompleted_data)
self.assertFalse(form.is_valid())

306
utils/tests.py Executable file
View file

@ -0,0 +1,306 @@
import uuid
from time import sleep
from unittest.mock import patch
import stripe
from celery.result import AsyncResult
from django.conf import settings
from django.http.request import HttpRequest
from django.test import Client
from django.test import TestCase, override_settings
from unittest import skipIf
from model_mommy import mommy
from datacenterlight.models import StripePlan
from membership.models import StripeCustomer
from utils.stripe_utils import StripeUtils
from .tasks import save_ssh_key
class BaseTestCase(TestCase):
"""
Base class to initialize the test cases
"""
def setUp(self):
# Password
self.dummy_password = 'test_password'
# Users
self.customer, self.another_customer = mommy.make(
'membership.CustomUser', validated=1, _quantity=2
)
self.customer.set_password(self.dummy_password)
self.customer.save()
self.another_customer.set_password(self.dummy_password)
self.another_customer.save()
# Stripe mocked data
self.stripe_mocked_customer = self.customer_stripe_mocked_data()
# Clients
self.customer_client = self.get_client(self.customer)
self.another_customer_client = self.get_client(self.another_customer)
# Request Object
self.request = HttpRequest()
self.request.META['SERVER_NAME'] = 'ungleich.ch'
self.request.META['SERVER_PORT'] = '80'
def get_client(self, user):
"""
Authenticate a user and return the client
"""
client = Client()
client.login(email=user.email, password=self.dummy_password)
return client
def customer_stripe_mocked_data(self):
return {
"id": "cus_8R1y9UWaIIjZqr",
"object": "customer",
"currency": "usd",
"default_source": "card_18A9up2eZvKYlo2Cq2RJMGeF",
"email": "vmedixtodd+1@gmail.com",
"livemode": False,
"metadata": {
},
"shipping": None,
"sources": {
"object": "list",
"data": [{
"id": "card_18A9up2eZvKYlo2Cq2RJMGeF",
"object": "card",
"brand": "Visa",
"country": "US",
"customer": "cus_8R1y9UWaIIjZqr",
"cvc_check": "pass",
"dynamic_last4": None,
"exp_month": 12,
"exp_year": 2018,
"funding": "credit",
"last4": "4242",
}]
}
}
def setup_view(self, view, *args, **kwargs):
"""Mimic as_view() returned callable, but returns view instance.
args and kwargs are the same you would pass to ``reverse()``
"""
view.request = self.request
view.args = args
view.kwargs = kwargs
view.config = None
return view
@skipIf(settings.STRIPE_API_PRIVATE_KEY_TEST is None or
settings.STRIPE_API_PRIVATE_KEY_TEST is "",
"""Skip because STRIPE_API_PRIVATE_KEY_TEST is not set""")
class TestStripeCustomerDescription(TestCase):
"""
A class to test setting the description field of the stripe customer
https://stripe.com/docs/api#metadata
"""
def setUp(self):
self.customer_password = 'test_password'
self.customer_email = 'test@ungleich.ch'
self.customer_name = "Monty Python"
self.customer = mommy.make('membership.CustomUser')
self.customer.set_password(self.customer_password)
self.customer.email = self.customer_email
self.customer.save()
self.stripe_utils = StripeUtils()
stripe.api_key = settings.STRIPE_API_PRIVATE_KEY_TEST
self.token = stripe.Token.create(
card={
"number": '4111111111111111',
"exp_month": 12,
"exp_year": 2022,
"cvc": '123'
},
)
self.failed_token = stripe.Token.create(
card={
"number": '4000000000000341',
"exp_month": 12,
"exp_year": 2022,
"cvc": '123'
},
)
def test_creating_stripe_customer(self):
stripe_data = self.stripe_utils.create_customer(self.token.id,
self.customer.email,
self.customer_name)
self.assertEqual(stripe_data.get('error'), None)
customer_data = stripe_data.get('response_object')
self.assertEqual(customer_data.description, self.customer_name)
@skipIf(settings.STRIPE_API_PRIVATE_KEY_TEST == "" or
settings.TEST_MANAGE_SSH_KEY_HOST == "",
"""Skipping test_save_ssh_key_add because either host
or public key were not specified or were empty""")
class StripePlanTestCase(TestStripeCustomerDescription):
"""
A class to test Stripe plans
"""
def test_get_stripe_plan_id_string(self):
plan_id_string = StripeUtils.get_stripe_plan_id(cpu=2, ram=20, ssd=100,
version=1, app='dcl')
self.assertEqual(plan_id_string, 'dcl-v1-cpu-2-ram-20gb-ssd-100gb')
plan_id_string = StripeUtils.get_stripe_plan_id(cpu=2, ram=20, ssd=100,
version=1, app='dcl',
hdd=200)
self.assertEqual(plan_id_string,
'dcl-v1-cpu-2-ram-20gb-ssd-100gb-hdd-200gb')
def test_get_or_create_plan(self):
stripe_plan = self.stripe_utils.get_or_create_stripe_plan(2000,
"test plan 1",
stripe_plan_id='test-plan-1')
self.assertIsNone(stripe_plan.get('error'))
self.assertIsInstance(stripe_plan.get('response_object'), StripePlan)
@skipIf(settings.TEST_MANAGE_SSH_KEY_PUBKEY == "" or
settings.TEST_MANAGE_SSH_KEY_HOST == "",
"""Skipping test_save_ssh_key_add because either host
or public key were not specified or were empty""")
@patch('utils.stripe_utils.logger')
def test_create_duplicate_plans_error_handling(self, mock_logger):
"""
Test details:
1. Create a test plan in Stripe with a particular id
2. Try to recreate the plan with the same id
3. This creates a Stripe error, the code should be able to handle the error
:param mock_logger:
:return:
"""
unique_id = str(uuid.uuid4().hex)
new_plan_id_str = 'test-plan-{}'.format(unique_id)
stripe_plan = self.stripe_utils.get_or_create_stripe_plan(2000,
"test plan {}".format(
unique_id),
stripe_plan_id=new_plan_id_str)
self.assertIsInstance(stripe_plan.get('response_object'), StripePlan)
self.assertEqual(stripe_plan.get('response_object').stripe_plan_id,
new_plan_id_str)
# Test creating the same plan again and expect the PLAN_EXISTS_ERROR_MSG
# We first delete the local Stripe Plan, so that the code tries to create a new plan in Stripe
StripePlan.objects.filter(
stripe_plan_id=new_plan_id_str).all().delete()
stripe_plan_1 = self.stripe_utils.get_or_create_stripe_plan(2000,
"test plan {}".format(
unique_id),
stripe_plan_id=new_plan_id_str)
mock_logger.debug.assert_called_with(
self.stripe_utils.PLAN_EXISTS_ERROR_MSG.format(new_plan_id_str))
self.assertIsInstance(stripe_plan_1.get('response_object'), StripePlan)
self.assertEqual(stripe_plan_1.get('response_object').stripe_plan_id,
new_plan_id_str)
# Delete the test stripe plan that we just created
delete_result = self.stripe_utils.delete_stripe_plan(new_plan_id_str)
self.assertIsInstance(delete_result, dict)
self.assertEqual(delete_result.get('response_object'), True)
@patch('utils.stripe_utils.logger')
def test_delete_unexisting_plan_should_fail(self, mock_logger):
plan_id = 'crazy-plan-id-that-does-not-exist'
result = self.stripe_utils.delete_stripe_plan(plan_id)
self.assertIsInstance(result, dict)
self.assertEqual(result.get('response_object'), False)
mock_logger.debug.assert_called_with(
self.stripe_utils.PLAN_DOES_NOT_EXIST_ERROR_MSG.format(plan_id))
def test_subscribe_customer_to_plan(self):
stripe_plan = self.stripe_utils.get_or_create_stripe_plan(2000,
"test plan 1",
stripe_plan_id='test-plan-1')
stripe_customer = StripeCustomer.get_or_create(
email=self.customer_email,
token=self.token)
result = self.stripe_utils.subscribe_customer_to_plan(
stripe_customer.stripe_id,
[{"plan": stripe_plan.get(
'response_object').stripe_plan_id}])
self.assertIsInstance(result.get('response_object'),
stripe.Subscription)
self.assertIsNone(result.get('error'))
self.assertEqual(result.get('response_object').get('status'), 'active')
def test_subscribe_customer_to_plan_failed_payment(self):
stripe_plan = self.stripe_utils.get_or_create_stripe_plan(2000,
"test plan 1",
stripe_plan_id='test-plan-1')
stripe_customer = StripeCustomer.get_or_create(
email=self.customer_email,
token=self.failed_token)
result = self.stripe_utils.subscribe_customer_to_plan(
stripe_customer.stripe_id,
[{"plan": stripe_plan.get(
'response_object').stripe_plan_id}])
self.assertIsNone(result.get('response_object'), None)
self.assertIsNotNone(result.get('error'))
class SaveSSHKeyTestCase(TestCase):
"""
A test case to test the celery save_ssh_key task
"""
@override_settings(
task_eager_propagates=True,
task_always_eager=True,
)
def setUp(self):
self.public_key = settings.TEST_MANAGE_SSH_KEY_PUBKEY
self.hosts = settings.TEST_MANAGE_SSH_KEY_HOST
@skipIf(settings.TEST_MANAGE_SSH_KEY_PUBKEY is "" or
settings.TEST_MANAGE_SSH_KEY_PUBKEY is None or
settings.TEST_MANAGE_SSH_KEY_HOST is "" or
settings.TEST_MANAGE_SSH_KEY_HOST is None,
"""Skipping test_save_ssh_key_add because either host
or public key were not specified or were empty""")
def test_save_ssh_key_add(self):
async_task = save_ssh_key.delay([self.hosts],
[{'value': self.public_key,
'state': True}])
save_ssh_key_result = None
for i in range(0, 10):
sleep(5)
res = AsyncResult(async_task.task_id)
if type(res.result) is bool:
save_ssh_key_result = res.result
break
self.assertIsNotNone(save_ssh_key, "save_ssh_key_result is None")
self.assertTrue(save_ssh_key_result, "save_ssh_key_result is False")
@skipIf(settings.TEST_MANAGE_SSH_KEY_PUBKEY is None or
settings.TEST_MANAGE_SSH_KEY_PUBKEY == "" or
settings.TEST_MANAGE_SSH_KEY_HOST is None or
settings.TEST_MANAGE_SSH_KEY_HOST is "",
"""Skipping test_save_ssh_key_add because either host
or public key were not specified or were empty""")
def test_save_ssh_key_remove(self):
async_task = save_ssh_key.delay([self.hosts],
[{'value': self.public_key,
'state': False}])
save_ssh_key_result = None
for i in range(0, 10):
sleep(5)
res = AsyncResult(async_task.task_id)
if type(res.result) is bool:
save_ssh_key_result = res.result
break
self.assertIsNotNone(save_ssh_key, "save_ssh_key_result is None")
self.assertTrue(save_ssh_key_result, "save_ssh_key_result is False")

269
utils/views.py Executable file
View file

@ -0,0 +1,269 @@
import uuid
from django.conf import settings
from django.contrib import messages
from django.contrib.auth import authenticate, login
from django.contrib.auth.tokens import default_token_generator
from django.core.files.base import ContentFile
from django.urls import reverse_lazy
from django.http import HttpResponseRedirect
from django.shortcuts import render
from django.utils.encoding import force_bytes
from django.utils.http import urlsafe_base64_encode, urlsafe_base64_decode
from django.utils.translation import gettext_lazy as _
from django.views.decorators.cache import cache_control
from django.views.generic import FormView, CreateView
from datacenterlight.utils import get_cms_integration
from hosting.forms import UserHostingKeyForm
from hosting.models import UserHostingKey
from membership.models import CustomUser
from opennebula_api.opennebula_manager import OpenNebulaManager
from utils.hosting_utils import get_all_public_keys
from .forms import SetPasswordForm
from .mailer import BaseEmail
class SignupViewMixin(CreateView):
model = CustomUser
success_url = None
def get_success_url(self):
next_url = self.request.POST.get('next') if self.request.POST.get(
'next') \
else self.success_url
return next_url
def form_valid(self, form):
name = form.cleaned_data.get('name')
email = form.cleaned_data.get('email')
password = form.cleaned_data.get('password')
CustomUser.register(name, password, email)
auth_user = authenticate(email=email, password=password)
login(self.request, auth_user)
return HttpResponseRedirect(self.get_success_url())
class LoginViewMixin(FormView):
success_url = None
def get_success_url(self):
next_url = self.request.POST.get('next', self.success_url)
if not next_url:
return self.success_url
return next_url
def form_valid(self, form):
email = form.cleaned_data.get('email')
password = form.cleaned_data.get('password')
auth_user = authenticate(email=email, password=password)
if auth_user:
login(self.request, auth_user)
return HttpResponseRedirect(self.get_success_url())
return HttpResponseRedirect(self.get_success_url())
@cache_control(no_cache=True, must_revalidate=True, no_store=True)
def get(self, request, *args, **kwargs):
if self.request.user.is_authenticated():
return HttpResponseRedirect(self.get_success_url())
return super(LoginViewMixin, self).get(request, *args, **kwargs)
class ResendActivationLinkViewMixin(FormView):
success_message = _(
"An email with the activation link has been sent to you")
def generate_email_context(self, user):
context = {
'base_url': "{0}://{1}".format(self.request.scheme,
self.request.get_host()),
'activation_link': reverse_lazy(
'hosting:validate',
kwargs={'validate_slug': user.validation_slug}
),
'dcl_text': settings.DCL_TEXT,
}
return context
def form_valid(self, form):
email = form.cleaned_data.get('email')
user = CustomUser.objects.get(email=email)
messages.add_message(self.request, messages.SUCCESS,
self.success_message)
context = self.generate_email_context(user)
email_data = {
'subject': '{dcl_text} {account_activation}'.format(
dcl_text=settings.DCL_TEXT,
account_activation=_('Account Activation')
),
'to': email,
'context': context,
'template_name': self.email_template_name,
'template_path': self.email_template_path,
'from_address': settings.DCL_SUPPORT_FROM_ADDRESS
}
email = BaseEmail(**email_data)
email.send()
return HttpResponseRedirect(self.get_success_url())
class PasswordResetViewMixin(FormView):
success_message = _(
"The link to reset your password has been sent to your email")
site = ''
def test_generate_email_context(self, user):
context = {
'user': user,
'token': default_token_generator.make_token(user),
'uid': urlsafe_base64_encode(force_bytes(user.pk)),
'site_name': 'ungleich' if self.site != 'dcl' else settings.DCL_TEXT,
'base_url': "{0}://{1}".format(self.request.scheme,
self.request.get_host())
}
return context
def form_valid(self, form):
email = form.cleaned_data.get('email')
user = CustomUser.objects.get(email=email)
messages.add_message(self.request, messages.SUCCESS,
self.success_message)
context = self.test_generate_email_context(user)
email_data = {
'subject': _('Password Reset'),
'to': email,
'context': context,
'template_name': 'password_reset_email',
'template_path': self.template_email_path
}
if self.site == 'dcl':
email_data['from_address'] = settings.DCL_SUPPORT_FROM_ADDRESS
email = BaseEmail(**email_data)
email.send()
return HttpResponseRedirect(self.get_success_url())
class PasswordResetConfirmViewMixin(FormView):
form_class = SetPasswordForm
def post(self, request, uidb64=None, token=None, *arg, **kwargs):
try:
uid = urlsafe_base64_decode(uidb64)
user = CustomUser.objects.get(pk=uid)
except (TypeError, ValueError, OverflowError, CustomUser.DoesNotExist):
user = None
form = self.form_class(request.POST)
if user is not None and default_token_generator.check_token(user,
token):
if form.is_valid():
new_password = form.cleaned_data['new_password2']
user.set_password(new_password)
user.save()
messages.success(request, _('Password has been reset.'))
return self.form_valid(form)
else:
messages.error(request,
_('Password reset has not been successful.'))
form.add_error(None,
_('Password reset has not been successful.'))
return self.form_invalid(form)
else:
messages.error(request,
_('The reset password link is no longer valid.'))
form.add_error(None,
_('The reset password link is no longer valid.'))
return self.form_invalid(form)
class SSHKeyCreateView(FormView):
form_class = UserHostingKeyForm
model = UserHostingKey
template_name = 'hosting/user_key.html'
login_url = reverse_lazy('hosting:login')
context_object_name = "virtual_machine"
success_url = reverse_lazy('hosting:ssh_keys')
def get_form_kwargs(self):
kwargs = super(SSHKeyCreateView, self).get_form_kwargs()
kwargs.update({'request': self.request})
return kwargs
def form_valid(self, form):
form.save()
if settings.DCL_SSH_KEY_NAME_PREFIX in form.instance.name:
content = ContentFile(form.cleaned_data.get('private_key'))
filename = form.cleaned_data.get(
'name') + '_' + str(uuid.uuid4())[:8] + '_private.pem'
form.instance.private_key.save(filename, content)
context = self.get_context_data()
next_url = self.request.session.get(
'next',
reverse_lazy('hosting:create_virtual_machine')
)
if 'next' in self.request.session:
context.update({
'next_url': next_url
})
del (self.request.session['next'])
if form.cleaned_data.get('private_key'):
context.update({
'private_key': form.cleaned_data.get('private_key'),
'key_name': form.cleaned_data.get('name'),
'form': UserHostingKeyForm(request=self.request),
})
if self.request.user.is_authenticated():
owner = self.request.user
manager = OpenNebulaManager(
email=owner.username,
password=owner.password
)
keys_to_save = get_all_public_keys(self.request.user)
manager.save_key_in_opennebula_user('\n'.join(keys_to_save))
else:
self.request.session["new_user_hosting_key_id"] = form.instance.id
return HttpResponseRedirect(self.success_url)
def post(self, request, *args, **kwargs):
form = self.get_form()
required = 'add_ssh' in self.request.POST
form.fields['name'].required = required
form.fields['public_key'].required = required
if form.is_valid():
return self.form_valid(form)
else:
return self.form_invalid(form)
class AskSSHKeyView(SSHKeyCreateView):
form_class = UserHostingKeyForm
template_name = "datacenterlight/add_ssh_key.html"
success_url = reverse_lazy('datacenterlight:order_confirmation')
context_object_name = "dcl_vm_buy_add_ssh_key"
@cache_control(no_cache=True, must_revalidate=True, no_store=True)
def get(self, request, *args, **kwargs):
context = {
'site_url': reverse_lazy('datacenterlight:index'),
'cms_integration': get_cms_integration('default'),
'form': UserHostingKeyForm(request=self.request),
'keys': get_all_public_keys(self.request.user)
}
return render(request, self.template_name, context)
def post(self, request, *args, **kwargs):
self.success_url = self.request.session.get("order_confirm_url")
return super(AskSSHKeyView, self).post(self, request, *args, **kwargs)