Merge branch 'master' of code.ungleich.ch:uncloud/uncloud

This commit is contained in:
Nico Schottelius 2020-04-11 21:37:50 +02:00
commit bab59b1879
21 changed files with 733 additions and 241 deletions

View file

@ -1,8 +1,22 @@
image: python:3 stages:
- lint
- test
before_script: run-tests:
- python setup.py install stage: test
image: fedora:latest
python_tests: services:
script: - postgres:latest
- python -m unittest -v test/test_mac_local.py variables:
DATABASE_HOST: postgres
DATABASE_USER: postgres
POSTGRES_HOST_AUTH_METHOD: trust
coverage: /^TOTAL.+?(\d+\%)$/
before_script:
- dnf install -y python3-devel python3-pip libpq-devel openldap-devel gcc
script:
- cd uncloud_django_based/uncloud
- pip install -r requirements.txt
- cp uncloud/secrets_sample.py uncloud/secrets.py
- coverage run --source='.' ./manage.py test
- coverage report

View file

@ -1,5 +1,21 @@
* What is a remote uncloud client? * What is a remote uncloud client?
** Systems that configure themselves for the use with uncloud ** Systems that configure themselves for the use with uncloud
** Examples are VMHosts, VPN Servers, etc. ** Examples are VMHosts, VPN Servers, cdist control server, etc.
* Which access do these clients need? * Which access do these clients need?
** They need read / write access to the database ** They need read / write access to the database
* Possible methods
** Overview
| | pros | cons |
| SSL based | Once setup, can access all django parts natively, locally | X.509 infrastructure |
| SSH -L tunnel | All nodes can use [::1]:5432 | SSH setup can be fragile |
| ssh djangohost manage.py | All DB ops locally | Code is only executed on django host |
| https + token | Rest alike / consistent access | Code is only executed on django host |
** remote vs. local Django code execution
- If manage.py is executed locally (= on the client), it can
check/modify local configs
- However local execution requires a pyvenv + packages + db access
- Local execution also *could* make use of postgresql notify for
triggering actions (which is quite neat)
- Remote execution (= on the primary django host) can acess the db
via unix socket
- However remote execution cannot check local state

View file

@ -1,5 +1,4 @@
# Live/test key from stripe from django.core.management.utils import get_random_secret_key
STRIPE_KEY = ''
# XML-RPC interface of opennebula # XML-RPC interface of opennebula
OPENNEBULA_URL = 'https://opennebula.ungleich.ch:2634/RPC2' OPENNEBULA_URL = 'https://opennebula.ungleich.ch:2634/RPC2'
@ -15,6 +14,8 @@ LDAP_ADMIN_PASSWORD=""
LDAP_SERVER_URI = "" LDAP_SERVER_URI = ""
# Stripe (Credit Card payments) # Stripe (Credit Card payments)
STRIPE_API_key="" STRIPE_KEY=""
STRIPE_PUBLIC_KEY=""
SECRET_KEY="dx$iqt=lc&yrp^!z5$ay^%g5lhx1y3bcu=jg(jx0yj0ogkfqvf" # The django secret key
SECRET_KEY=get_random_secret_key()

View file

@ -27,8 +27,9 @@ except ModuleNotFoundError:
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.postgresql', 'ENGINE': 'django.db.backends.postgresql',
'HOST': '::1', # connecting via tcp, v6, to allow ssh forwarding to work
'NAME': uncloud.secrets.POSTGRESQL_DB_NAME, 'NAME': uncloud.secrets.POSTGRESQL_DB_NAME,
'HOST': os.environ.get('DATABASE_HOST', '::1'),
'USER': os.environ.get('DATABASE_USER', 'postgres'),
} }
} }

View file

@ -54,7 +54,6 @@ router.register(r'payment-method', payviews.PaymentMethodViewSet, basename='paym
router.register(r'bill', payviews.BillViewSet, basename='bill') router.register(r'bill', payviews.BillViewSet, basename='bill')
router.register(r'order', payviews.OrderViewSet, basename='order') router.register(r'order', payviews.OrderViewSet, basename='order')
router.register(r'payment', payviews.PaymentViewSet, basename='payment') router.register(r'payment', payviews.PaymentViewSet, basename='payment')
router.register(r'payment-method', payviews.PaymentMethodViewSet, basename='payment-methods')
# admin/staff urls # admin/staff urls

View file

@ -0,0 +1,24 @@
# Generated by Django 3.0.5 on 2020-04-09 12:25
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('uncloud_net', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='vpnnetworkreservation',
name='status',
field=models.CharField(choices=[('used', 'used'), ('free', 'free')], default='used', max_length=256),
),
migrations.AlterField(
model_name='vpnnetwork',
name='network',
field=models.ForeignKey(editable=False, on_delete=django.db.models.deletion.CASCADE, to='uncloud_net.VPNNetworkReservation'),
),
]

View file

@ -112,6 +112,7 @@ class VPNNetworkReservation(UncloudModel):
address = models.GenericIPAddressField(primary_key=True) address = models.GenericIPAddressField(primary_key=True)
status = models.CharField(max_length=256, status = models.CharField(max_length=256,
default='used',
choices = ( choices = (
('used', 'used'), ('used', 'used'),
('free', 'free') ('free', 'free')

View file

@ -0,0 +1,27 @@
# Generated by Django 3.0.3 on 2020-03-05 15:24
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('uncloud_pay', '0001_initial'),
]
operations = [
migrations.RenameField(
model_name='paymentmethod',
old_name='stripe_card_id',
new_name='stripe_payment_method_id',
),
migrations.AddField(
model_name='paymentmethod',
name='stripe_setup_intent_id',
field=models.CharField(blank=True, max_length=32, null=True),
),
migrations.AlterUniqueTogether(
name='paymentmethod',
unique_together=set(),
),
]

View file

@ -0,0 +1,18 @@
# Generated by Django 3.0.3 on 2020-03-05 13:54
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('uncloud_pay', '0002_auto_20200305_1524'),
]
operations = [
migrations.AlterField(
model_name='paymentmethod',
name='primary',
field=models.BooleanField(default=False, editable=False),
),
]

View file

@ -0,0 +1,23 @@
# Generated by Django 3.0.5 on 2020-04-09 12:25
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('uncloud_pay', '0003_auto_20200305_1354'),
]
operations = [
migrations.AlterField(
model_name='order',
name='recurring_period',
field=models.CharField(choices=[('ONCE', 'Onetime'), ('YEAR', 'Per Year'), ('MONTH', 'Per Month'), ('MINUTE', 'Per Minute'), ('WEEK', 'Per Week'), ('DAY', 'Per Day'), ('HOUR', 'Per Hour'), ('SECOND', 'Per Second')], default='MONTH', max_length=32),
),
migrations.AlterField(
model_name='order',
name='starting_date',
field=models.DateTimeField(),
),
]

View file

@ -4,11 +4,10 @@ from django.contrib.auth import get_user_model
from django.core.validators import MinValueValidator from django.core.validators import MinValueValidator
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils import timezone from django.utils import timezone
from django.dispatch import receiver
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
import django.db.models.signals as signals
import uuid import uuid
import logging
from functools import reduce from functools import reduce
from math import ceil from math import ceil
from datetime import timedelta from datetime import timedelta
@ -21,16 +20,29 @@ from uncloud_pay.helpers import beginning_of_month, end_of_month
from uncloud import AMOUNT_DECIMALS, AMOUNT_MAX_DIGITS from uncloud import AMOUNT_DECIMALS, AMOUNT_MAX_DIGITS
from uncloud.models import UncloudModel, UncloudStatus from uncloud.models import UncloudModel, UncloudStatus
from decimal import Decimal
import decimal
# Define DecimalField properties, used to represent amounts of money.
AMOUNT_MAX_DIGITS=10
AMOUNT_DECIMALS=2
# FIXME: check why we need +1 here.
decimal.getcontext().prec = AMOUNT_DECIMALS + 1
# Used to generate bill due dates. # Used to generate bill due dates.
BILL_PAYMENT_DELAY=timedelta(days=10) BILL_PAYMENT_DELAY=timedelta(days=10)
# Initialize logger.
logger = logging.getLogger(__name__)
# See https://docs.djangoproject.com/en/dev/ref/models/fields/#field-choices-enum-types # See https://docs.djangoproject.com/en/dev/ref/models/fields/#field-choices-enum-types
class RecurringPeriod(models.TextChoices): class RecurringPeriod(models.TextChoices):
ONE_TIME = 'ONCE', _('Onetime') ONE_TIME = 'ONCE', _('Onetime')
PER_YEAR = 'YEAR', _('Per Year') PER_YEAR = 'YEAR', _('Per Year')
PER_MONTH = 'MONTH', _('Per Month') PER_MONTH = 'MONTH', _('Per Month')
PER_MINUTE = 'MINUTE', _('Per Minute') PER_MINUTE = 'MINUTE', _('Per Minute')
PER_WEEK = 'WEEK', _('Per Week')
PER_DAY = 'DAY', _('Per Day') PER_DAY = 'DAY', _('Per Day')
PER_HOUR = 'HOUR', _('Per Hour') PER_HOUR = 'HOUR', _('Per Hour')
PER_SECOND = 'SECOND', _('Per Second') PER_SECOND = 'SECOND', _('Per Second')
@ -106,57 +118,76 @@ class PaymentMethod(models.Model):
), ),
default='stripe') default='stripe')
description = models.TextField() description = models.TextField()
primary = models.BooleanField(default=True) primary = models.BooleanField(default=False, editable=False)
# Only used for "Stripe" source # Only used for "Stripe" source
stripe_card_id = models.CharField(max_length=32, blank=True, null=True) stripe_payment_method_id = models.CharField(max_length=32, blank=True, null=True)
stripe_setup_intent_id = models.CharField(max_length=32, blank=True, null=True)
@property @property
def stripe_card_last4(self): def stripe_card_last4(self):
if self.source == 'stripe': if self.source == 'stripe' and self.active:
card_request = uncloud_pay.stripe.get_card( payment_method = uncloud_pay.stripe.get_payment_method(
StripeCustomer.objects.get(owner=self.owner).stripe_id, self.stripe_payment_method_id)
self.stripe_card_id) return payment_method.card.last4
if card_request['error'] == None:
return card_request['response_object']['last4']
else:
return None
else: else:
return None return None
@property
def active(self):
if self.source == 'stripe' and self.stripe_payment_method_id != None:
return True
else:
return False
def charge(self, amount): def charge(self, amount):
if amount > 0: # Make sure we don't charge negative amount by errors... if not self.active:
if self.source == 'stripe': raise Exception('This payment method is inactive.')
stripe_customer = StripeCustomer.objects.get(owner=self.owner).stripe_id
charge_request = uncloud_pay.stripe.charge_customer(amount, stripe_customer, self.stripe_card_id)
if charge_request['error'] == None:
payment = Payment(owner=self.owner, source=self.source, amount=amount)
payment.save() # TODO: Check return status
return payment if amount < 0: # Make sure we don't charge negative amount by errors...
else:
raise Exception('Stripe error: {}'.format(charge_request['error']))
else:
raise Exception('This payment method is unsupported/cannot be charged.')
else:
raise Exception('Cannot charge negative amount.') raise Exception('Cannot charge negative amount.')
if self.source == 'stripe':
stripe_customer = StripeCustomer.objects.get(owner=self.owner).stripe_id
stripe_payment = uncloud_pay.stripe.charge_customer(
amount, stripe_customer, self.stripe_payment_method_id)
if 'paid' in stripe_payment and stripe_payment['paid'] == False:
raise Exception(stripe_payment['error'])
else:
payment = Payment.objects.create(
owner=self.owner, source=self.source, amount=amount)
return payment
else:
raise Exception('This payment method is unsupported/cannot be charged.')
def set_as_primary_for(self, user):
methods = PaymentMethod.objects.filter(owner=user, primary=True)
for method in methods:
print(method)
method.primary = False
method.save()
self.primary = True
self.save()
def get_primary_for(user): def get_primary_for(user):
methods = PaymentMethod.objects.filter(owner=user) methods = PaymentMethod.objects.filter(owner=user)
for method in methods: for method in methods:
# Do we want to do something with non-primary method? # Do we want to do something with non-primary method?
if method.primary: if method.active and method.primary:
return method return method
return None return None
class Meta: class Meta:
unique_together = [['owner', 'primary']] # TODO: limit to one primary method per user.
# unique_together is no good since it won't allow more than one
# non-primary method.
pass
### ###
# Bills & Payments. # Bills.
class Bill(models.Model): class Bill(models.Model):
uuid = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) uuid = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
@ -199,51 +230,108 @@ class Bill(models.Model):
@staticmethod @staticmethod
def generate_for(year, month, user): def generate_for(year, month, user):
# /!\ We exclusively work on the specified year and month. # /!\ We exclusively work on the specified year and month.
generated_bills = []
# Default values for next bill (if any). Only saved at the end of # Default values for next bill (if any).
# this method, if relevant. starting_date=beginning_of_month(year, month)
next_bill = Bill(owner=user, ending_date=end_of_month(year, month)
starting_date=beginning_of_month(year, month), creation_date=timezone.now()
ending_date=end_of_month(year, month),
creation_date=timezone.now(),
due_date=timezone.now() + BILL_PAYMENT_DELAY)
# Select all orders active on the request period. # Select all orders active on the request period (i.e. starting on or after starting_date).
orders = Order.objects.filter( orders = Order.objects.filter(
Q(ending_date__gt=next_bill.starting_date) | Q(ending_date__isnull=True), Q(ending_date__gte=starting_date) | Q(ending_date__isnull=True),
owner=user) owner=user)
# Check if there is already a bill covering the order and period pair: # Check if there is already a bill covering the order and period pair:
# * Get latest bill by ending_date: previous_bill.ending_date # * Get latest bill by ending_date: previous_bill.ending_date
# * If previous_bill.ending_date is before next_bill.ending_date, a new # * For monthly bills: if previous_bill.ending_date is before
# bill has to be generated. # (next_bill) ending_date, a new bill has to be generated.
unpaid_orders = [] # * For yearly bill: if previous_bill.ending_date is on working
# month, generate new bill.
unpaid_orders = { 'monthly_or_less': [], 'yearly': {}}
for order in orders: for order in orders:
try: try:
previous_bill = order.bill.latest('ending_date') previous_bill = order.bill.latest('ending_date')
except ObjectDoesNotExist: except ObjectDoesNotExist:
previous_bill = None previous_bill = None
if previous_bill == None or previous_bill.ending_date < next_bill.ending_date: # FIXME: control flow is confusing in this block.
unpaid_orders.append(order) if order.recurring_period == RecurringPeriod.PER_YEAR:
# We ignore anything smaller than a day in here.
next_yearly_bill_start_on = None
if previous_bill == None:
next_yearly_bill_start_on = order.starting_date
elif previous_bill.ending_date <= ending_date:
next_yearly_bill_start_on = (previous_bill.ending_date + timedelta(days=1))
# Commit next_bill if it there are 'unpaid' orders. # Store for bill generation. One bucket per day of month with a starting bill.
if len(unpaid_orders) > 0: # bucket is a reference here, no need to reassign.
next_bill.save() if next_yearly_bill_start_on:
# We want to group orders by date but keep using datetimes.
next_yearly_bill_start_on = next_yearly_bill_start_on.replace(
minute=0, hour=0, second=0, microsecond=0)
bucket = unpaid_orders['yearly'].get(next_yearly_bill_start_on)
if bucket == None:
unpaid_orders['yearly'][next_yearly_bill_start_on] = [order]
else:
unpaid_orders['yearly'][next_yearly_bill_start_on] = bucket + [order]
else:
if previous_bill == None or previous_bill.ending_date <= ending_date:
unpaid_orders['monthly_or_less'].append(order)
# Handle working month's billing.
if len(unpaid_orders['monthly_or_less']) > 0:
# TODO: PREPAID billing is not supported yet.
prepaid_due_date = min(creation_date, starting_date) + BILL_PAYMENT_DELAY
postpaid_due_date = max(creation_date, ending_date) + BILL_PAYMENT_DELAY
next_monthly_bill = Bill.objects.create(owner=user,
creation_date=creation_date,
starting_date=starting_date, # FIXME: this is a hack!
ending_date=ending_date,
due_date=postpaid_due_date)
# It is not possible to register many-to-many relationship before # It is not possible to register many-to-many relationship before
# the two end-objects are saved in database. # the two end-objects are saved in database.
for order in unpaid_orders: for order in unpaid_orders['monthly_or_less']:
order.bill.add(next_bill) order.bill.add(next_monthly_bill)
# TODO: use logger. logger.info("Generated monthly bill {} (amount: {}) for user {}."
print("Generated bill {} (amount: {}) for user {}." .format(next_monthly_bill.uuid, next_monthly_bill.total, user))
.format(next_bill.uuid, next_bill.total, user))
return next_bill # Add to output.
generated_bills.append(next_monthly_bill)
# Return None if no bill was created. # Handle yearly bills starting on working month.
return None if len(unpaid_orders['yearly']) > 0:
# For every starting date, generate new bill.
for next_yearly_bill_start_on in unpaid_orders['yearly']:
# No postpaid for yearly payments.
prepaid_due_date = min(creation_date, next_yearly_bill_start_on) + BILL_PAYMENT_DELAY
# Bump by one year, remove one day.
ending_date = next_yearly_bill_start_on.replace(
year=next_yearly_bill_start_on.year+1) - timedelta(days=1)
next_yearly_bill = Bill.objects.create(owner=user,
creation_date=creation_date,
starting_date=next_yearly_bill_start_on,
ending_date=ending_date,
due_date=prepaid_due_date)
# It is not possible to register many-to-many relationship before
# the two end-objects are saved in database.
for order in unpaid_orders['yearly'][next_yearly_bill_start_on]:
order.bill.add(next_yearly_bill)
logger.info("Generated yearly bill {} (amount: {}) for user {}."
.format(next_yearly_bill.uuid, next_yearly_bill.total, user))
# Add to output.
generated_bills.append(next_yearly_bill)
# Return generated (monthly + yearly) bills.
return generated_bills
@staticmethod @staticmethod
def get_unpaid_for(user): def get_unpaid_for(user):
@ -286,7 +374,7 @@ class BillRecord():
self.recurring_period = order_record.recurring_period self.recurring_period = order_record.recurring_period
self.description = order_record.description self.description = order_record.description
if self.order.starting_date > self.bill.starting_date: if self.order.starting_date >= self.bill.starting_date:
self.one_time_price = order_record.one_time_price self.one_time_price = order_record.one_time_price
else: else:
self.one_time_price = 0 self.one_time_price = 0
@ -295,7 +383,7 @@ class BillRecord():
def recurring_count(self): def recurring_count(self):
# Compute billing delta. # Compute billing delta.
billed_until = self.bill.ending_date billed_until = self.bill.ending_date
if self.order.ending_date != None and self.order.ending_date < self.order.ending_date: if self.order.ending_date != None and self.order.ending_date <= self.bill.ending_date:
billed_until = self.order.ending_date billed_until = self.order.ending_date
billed_from = self.bill.starting_date billed_from = self.bill.starting_date
@ -303,7 +391,7 @@ class BillRecord():
billed_from = self.order.starting_date billed_from = self.order.starting_date
if billed_from > billed_until: if billed_from > billed_until:
# TODO: think about and check edges cases. This should not be # TODO: think about and check edge cases. This should not be
# possible. # possible.
raise Exception('Impossible billing delta!') raise Exception('Impossible billing delta!')
@ -311,11 +399,14 @@ class BillRecord():
# TODO: refactor this thing? # TODO: refactor this thing?
# TODO: weekly # TODO: weekly
# TODO: yearly if self.recurring_period == RecurringPeriod.PER_YEAR:
if self.recurring_period == RecurringPeriod.PER_MONTH: # XXX: Should always be one => we do not bill for more than one year.
# TODO: check billed_delta is ~365 days.
return 1
elif self.recurring_period == RecurringPeriod.PER_MONTH:
days = ceil(billed_delta / timedelta(days=1)) days = ceil(billed_delta / timedelta(days=1))
# XXX: we assume monthly bills for now. # Monthly bills always cover one single month.
if (self.bill.starting_date.year != self.bill.starting_date.year or if (self.bill.starting_date.year != self.bill.starting_date.year or
self.bill.starting_date.month != self.bill.ending_date.month): self.bill.starting_date.month != self.bill.ending_date.month):
raise Exception('Bill {} covers more than one month. Cannot bill PER_MONTH.'. raise Exception('Bill {} covers more than one month. Cannot bill PER_MONTH.'.
@ -325,25 +416,28 @@ class BillRecord():
(_, days_in_month) = monthrange( (_, days_in_month) = monthrange(
self.bill.starting_date.year, self.bill.starting_date.year,
self.bill.starting_date.month) self.bill.starting_date.month)
return Decimal(days / days_in_month) return days / days_in_month
elif self.recurring_period == RecurringPeriod.PER_WEEK:
weeks = ceil(billed_delta / timedelta(week=1))
return weeks
elif self.recurring_period == RecurringPeriod.PER_DAY: elif self.recurring_period == RecurringPeriod.PER_DAY:
days = ceil(billed_delta / timedelta(days=1)) days = ceil(billed_delta / timedelta(days=1))
return Decimal(days) return days
elif self.recurring_period == RecurringPeriod.PER_HOUR: elif self.recurring_period == RecurringPeriod.PER_HOUR:
hours = ceil(billed_delta / timedelta(hours=1)) hours = ceil(billed_delta / timedelta(hours=1))
return Decimal(hours) return hours
elif self.recurring_period == RecurringPeriod.PER_SECOND: elif self.recurring_period == RecurringPeriod.PER_SECOND:
seconds = ceil(billed_delta / timedelta(seconds=1)) seconds = ceil(billed_delta / timedelta(seconds=1))
return Decimal(seconds) return seconds
elif self.recurring_period == RecurringPeriod.ONE_TIME: elif self.recurring_period == RecurringPeriod.ONE_TIME:
return Decimal(0) return 0
else: else:
raise Exception('Unsupported recurring period: {}.'. raise Exception('Unsupported recurring period: {}.'.
format(record.recurring_period)) format(record.recurring_period))
@property @property
def amount(self): def amount(self):
return self.recurring_price * self.recurring_count + self.one_time_price return Decimal(float(self.recurring_price) * self.recurring_count) + self.one_time_price
### ###
# Orders. # Orders.
@ -358,7 +452,7 @@ class Order(models.Model):
# TODO: enforce ending_date - starting_date to be larger than recurring_period. # TODO: enforce ending_date - starting_date to be larger than recurring_period.
creation_date = models.DateTimeField(auto_now_add=True) creation_date = models.DateTimeField(auto_now_add=True)
starting_date = models.DateTimeField(auto_now_add=True) starting_date = models.DateTimeField()
ending_date = models.DateTimeField(blank=True, ending_date = models.DateTimeField(blank=True,
null=True) null=True)

View file

@ -8,30 +8,28 @@ from .models import *
class PaymentSerializer(serializers.ModelSerializer): class PaymentSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Payment model = Payment
fields = ['owner', 'amount', 'source', 'timestamp'] fields = '__all__'
class PaymentMethodSerializer(serializers.ModelSerializer): class PaymentMethodSerializer(serializers.ModelSerializer):
stripe_card_last4 = serializers.IntegerField() stripe_card_last4 = serializers.IntegerField()
class Meta: class Meta:
model = PaymentMethod model = PaymentMethod
fields = ['uuid', 'source', 'description', 'primary', 'stripe_card_last4'] fields = ['uuid', 'source', 'description', 'primary', 'stripe_card_last4', 'active']
class UpdatePaymentMethodSerializer(serializers.ModelSerializer):
class Meta:
model = PaymentMethod
fields = ['description', 'primary']
class ChargePaymentMethodSerializer(serializers.Serializer): class ChargePaymentMethodSerializer(serializers.Serializer):
amount = serializers.DecimalField(max_digits=10, decimal_places=2) amount = serializers.DecimalField(max_digits=10, decimal_places=2)
class CreditCardSerializer(serializers.Serializer):
number = serializers.IntegerField()
exp_month = serializers.IntegerField()
exp_year = serializers.IntegerField()
cvc = serializers.IntegerField()
class CreatePaymentMethodSerializer(serializers.ModelSerializer): class CreatePaymentMethodSerializer(serializers.ModelSerializer):
credit_card = CreditCardSerializer() please_visit = serializers.CharField(read_only=True)
class Meta: class Meta:
model = PaymentMethod model = PaymentMethod
fields = ['source', 'description', 'primary', 'credit_card'] fields = ['source', 'description', 'primary', 'please_visit']
### ###
# Orders & Products. # Orders & Products.

View file

@ -10,9 +10,14 @@ import uncloud.secrets
# Static stripe configuration used below. # Static stripe configuration used below.
CURRENCY = 'chf' CURRENCY = 'chf'
# README: We use the Payment Intent API as described on
# https://stripe.com/docs/payments/save-and-reuse
# For internal use only.
stripe.api_key = uncloud.secrets.STRIPE_KEY stripe.api_key = uncloud.secrets.STRIPE_KEY
# Helper (decorator) used to catch errors raised by stripe logic. # Helper (decorator) used to catch errors raised by stripe logic.
# Catch errors that should not be displayed to the end user, raise again.
def handle_stripe_error(f): def handle_stripe_error(f):
def handle_problems(*args, **kwargs): def handle_problems(*args, **kwargs):
response = { response = {
@ -21,108 +26,84 @@ def handle_stripe_error(f):
'error': None 'error': None
} }
common_message = "Currently it is not possible to make payments." common_message = "Currently it is not possible to make payments. Please try agin later."
try: try:
response_object = f(*args, **kwargs) response_object = f(*args, **kwargs)
response = { return response_object
'response_object': response_object,
'error': None
}
return response
except stripe.error.CardError as e: except stripe.error.CardError as e:
# Since it's a decline, stripe.error.CardError will be caught # Since it's a decline, stripe.error.CardError will be caught
body = e.json_body body = e.json_body
err = body['error']
response.update({'error': err['message']})
logging.error(str(e)) logging.error(str(e))
return response
raise e # For error handling.
except stripe.error.RateLimitError: except stripe.error.RateLimitError:
response.update( logging.error("Too many requests made to the API too quickly.")
{'error': "Too many requests made to the API too quickly"}) raise Exception(common_message)
return response
except stripe.error.InvalidRequestError as e: except stripe.error.InvalidRequestError as e:
logging.error(str(e)) logging.error(str(e))
response.update({'error': "Invalid parameters"}) raise Exception('Invalid parameters.')
return response
except stripe.error.AuthenticationError as e: except stripe.error.AuthenticationError as e:
# Authentication with Stripe's API failed # Authentication with Stripe's API failed
# (maybe you changed API keys recently) # (maybe you changed API keys recently)
logging.error(str(e)) logging.error(str(e))
response.update({'error': common_message}) raise Exception(common_message)
return response
except stripe.error.APIConnectionError as e: except stripe.error.APIConnectionError as e:
logging.error(str(e)) logging.error(str(e))
response.update({'error': common_message}) raise Exception(common_message)
return response
except stripe.error.StripeError as e: except stripe.error.StripeError as e:
# maybe send email # XXX: maybe send email
logging.error(str(e)) logging.error(str(e))
response.update({'error': common_message}) raise Exception(common_message)
return response
except Exception as e: except Exception as e:
# maybe send email # maybe send email
logging.error(str(e)) logging.error(str(e))
response.update({'error': common_message}) raise Exception(common_message)
return response
return handle_problems return handle_problems
# Convenience CC container, also used for serialization.
class CreditCard():
number = None
exp_year = None
exp_month = None
cvc = None
def __init__(self, number, exp_month, exp_year, cvc):
self.number=number
self.exp_year = exp_year
self.exp_month = exp_month
self.cvc = cvc
# Actual Stripe logic. # Actual Stripe logic.
def public_api_key():
return uncloud.secrets.STRIPE_PUBLIC_KEY
def get_customer_id_for(user): def get_customer_id_for(user):
try: try:
# .get() raise if there is no matching entry. # .get() raise if there is no matching entry.
return uncloud_pay.models.StripeCustomer.objects.get(owner=user).stripe_id return uncloud_pay.models.StripeCustomer.objects.get(owner=user).stripe_id
except ObjectDoesNotExist: except ObjectDoesNotExist:
# No entry yet - making a new one. # No entry yet - making a new one.
customer_request = create_customer(user.username, user.email) try:
if customer_request['error'] == None: customer = create_customer(user.username, user.email)
mapping = uncloud_pay.models.StripeCustomer.objects.create( uncloud_stripe_mapping = uncloud_pay.models.StripeCustomer.objects.create(
owner=user, owner=user, stripe_id=customer.id)
stripe_id=customer_request['response_object']['id'] return uncloud_stripe_mapping.stripe_id
) except Exception as e:
return mapping.stripe_id
else:
return None return None
@handle_stripe_error @handle_stripe_error
def create_card(customer_id, credit_card): def create_setup_intent(customer_id):
return stripe.Customer.create_source( return stripe.SetupIntent.create(customer=customer_id)
customer_id,
card={
'number': credit_card.number,
'exp_month': credit_card.exp_month,
'exp_year': credit_card.exp_year,
'cvc': credit_card.cvc
})
@handle_stripe_error @handle_stripe_error
def get_card(customer_id, card_id): def get_setup_intent(setup_intent_id):
return stripe.Customer.retrieve_source(customer_id, card_id) return stripe.SetupIntent.retrieve(setup_intent_id)
def get_payment_method(payment_method_id):
return stripe.PaymentMethod.retrieve(payment_method_id)
@handle_stripe_error @handle_stripe_error
def charge_customer(amount, customer_id, card_id): def charge_customer(amount, customer_id, card_id):
# Amount is in CHF but stripes requires smallest possible unit. # Amount is in CHF but stripes requires smallest possible unit.
# See https://stripe.com/docs/api/charges/create # https://stripe.com/docs/api/payment_intents/create#create_payment_intent-amount
adjusted_amount = int(amount * 100) adjusted_amount = int(amount * 100)
return stripe.Charge.create( return stripe.PaymentIntent.create(
amount=adjusted_amount, amount=adjusted_amount,
currency=CURRENCY, currency=CURRENCY,
customer=customer_id, customer=customer_id,
source=card_id) payment_method=card_id,
off_session=True,
confirm=True,
)
@handle_stripe_error @handle_stripe_error
def create_customer(name, email): def create_customer(name, email):

View file

@ -0,0 +1,18 @@
<!DOCTYPE html>
<html>
<head>
<title>Error</title>
<style>
#content {
width: 400px;
margin: auto;
}
</style>
</head>
<body>
<div id="content">
<h1>Error</h1>
<p>{{ error }}</p>
</div>
</body>
</html>

View file

@ -0,0 +1,76 @@
<!DOCTYPE html>
<html>
<head>
<title>Stripe Card Registration</title>
<!-- https://stripe.com/docs/js/appendix/viewport_meta_requirements -->
<meta name="viewport" content="width=device-width, initial-scale=1" />
<script src="https://js.stripe.com/v3/"></script>
<style>
#content {
width: 400px;
margin: auto;
}
#callback-form {
display: none;
}
</style>
</head>
<body>
<div id="content">
<h1>Registering Stripe Credit Card</h1>
<!-- Stripe form and messages -->
<span id="message"></span>
<form id="setup-form">
<div id="card-element"></div>
<button type='button' id="card-button">
Save
</button>
</form>
<!-- Dirty hack used for callback to API -->
<form id="callback-form" action="{{ callback }}" method="post"></form>
</div>
<!-- Enable Stripe from UI elements -->
<script>
var stripe = Stripe('{{ stripe_pk }}');
var elements = stripe.elements();
var cardElement = elements.create('card');
cardElement.mount('#card-element');
</script>
<!-- Handle card submission -->
<script>
var cardButton = document.getElementById('card-button');
var messageContainer = document.getElementById('message');
var clientSecret = '{{ client_secret }}';
cardButton.addEventListener('click', function(ev) {
stripe.confirmCardSetup(
clientSecret,
{
payment_method: {
card: cardElement,
billing_details: {
},
},
}
).then(function(result) {
if (result.error) {
var message = document.createTextNode('Error:' + result.error.message);
messageContainer.appendChild(message);
} else {
// Return to API on success.
document.getElementById("callback-form").submit();
}
});
});
</script>
</body>
</html>

View file

@ -1,3 +1,118 @@
from django.test import TestCase from django.test import TestCase
from django.contrib.auth import get_user_model
from datetime import datetime, date, timedelta
# Create your tests here. from .models import *
class BillingTestCase(TestCase):
def setUp(self):
self.user = get_user_model().objects.create(
username='jdoe',
email='john.doe@domain.tld')
def test_truth(self):
self.assertEqual(1+1, 2)
def test_basic_monthly_billing(self):
one_time_price = 10
recurring_price = 20
description = "Test Product 1"
# Three months: full, full, partial.
starting_date = datetime.fromisoformat('2020-03-01')
ending_date = datetime.fromisoformat('2020-05-08')
# Create order to be billed.
order = Order.objects.create(
owner=self.user,
starting_date=starting_date,
ending_date=ending_date,
recurring_period=RecurringPeriod.PER_MONTH)
order.add_record(one_time_price, recurring_price, description)
# Generate & check bill for first month: full recurring_price + setup.
first_month_bills = Bill.generate_for(2020, 3, self.user)
self.assertEqual(len(first_month_bills), 1)
self.assertEqual(first_month_bills[0].total, one_time_price + recurring_price)
# Generate & check bill for second month: full recurring_price.
second_month_bills = Bill.generate_for(2020, 4, self.user)
self.assertEqual(len(second_month_bills), 1)
self.assertEqual(second_month_bills[0].total, recurring_price)
# Generate & check bill for third and last month: partial recurring_price.
third_month_bills = Bill.generate_for(2020, 5, self.user)
self.assertEqual(len(third_month_bills), 1)
# 31 days in May.
self.assertEqual(float(third_month_bills[0].total),
round((7/31) * recurring_price, AMOUNT_DECIMALS))
# Check that running Bill.generate_for() twice does not create duplicates.
self.assertEqual(len(Bill.generate_for(2020, 3, self.user)), 0)
def test_basic_yearly_billing(self):
one_time_price = 10
recurring_price = 150
description = "Test Product 1"
starting_date = datetime.fromisoformat('2020-03-31T08:05:23')
# Create order to be billed.
order = Order.objects.create(
owner=self.user,
starting_date=starting_date,
recurring_period=RecurringPeriod.PER_YEAR)
order.add_record(one_time_price, recurring_price, description)
# Generate & check bill for first year: recurring_price + setup.
first_year_bills = Bill.generate_for(2020, 3, self.user)
self.assertEqual(len(first_year_bills), 1)
self.assertEqual(first_year_bills[0].starting_date.date(),
date.fromisoformat('2020-03-31'))
self.assertEqual(first_year_bills[0].ending_date.date(),
date.fromisoformat('2021-03-30'))
self.assertEqual(first_year_bills[0].total,
recurring_price + one_time_price)
# Generate & check bill for second year: recurring_price.
second_year_bills = Bill.generate_for(2021, 3, self.user)
self.assertEqual(len(second_year_bills), 1)
self.assertEqual(second_year_bills[0].starting_date.date(),
date.fromisoformat('2021-03-31'))
self.assertEqual(second_year_bills[0].ending_date.date(),
date.fromisoformat('2022-03-30'))
self.assertEqual(second_year_bills[0].total, recurring_price)
# Check that running Bill.generate_for() twice does not create duplicates.
self.assertEqual(len(Bill.generate_for(2020, 3, self.user)), 0)
self.assertEqual(len(Bill.generate_for(2020, 4, self.user)), 0)
self.assertEqual(len(Bill.generate_for(2020, 2, self.user)), 0)
self.assertEqual(len(Bill.generate_for(2021, 3, self.user)), 0)
def test_basic_hourly_billing(self):
one_time_price = 10
recurring_price = 1.4
description = "Test Product 1"
starting_date = datetime.fromisoformat('2020-03-31T08:05:23')
ending_date = datetime.fromisoformat('2020-04-01T11:13:32')
# Create order to be billed.
order = Order.objects.create(
owner=self.user,
starting_date=starting_date,
ending_date=ending_date,
recurring_period=RecurringPeriod.PER_HOUR)
order.add_record(one_time_price, recurring_price, description)
# Generate & check bill for first month: recurring_price + setup.
first_month_bills = Bill.generate_for(2020, 3, self.user)
self.assertEqual(len(first_month_bills), 1)
self.assertEqual(float(first_month_bills[0].total),
round(16 * recurring_price, AMOUNT_DECIMALS) + one_time_price)
# Generate & check bill for first month: recurring_price.
second_month_bills = Bill.generate_for(2020, 4, self.user)
self.assertEqual(len(second_month_bills), 1)
self.assertEqual(float(second_month_bills[0].total),
round(12 * recurring_price, AMOUNT_DECIMALS))

View file

@ -1,9 +1,12 @@
from django.shortcuts import render from django.shortcuts import render
from django.db import transaction from django.db import transaction
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework import viewsets, permissions, status from rest_framework import viewsets, permissions, status, views
from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.reverse import reverse
from rest_framework.decorators import renderer_classes
import json import json
@ -13,26 +16,7 @@ from datetime import datetime
import uncloud_pay.stripe as uncloud_stripe import uncloud_pay.stripe as uncloud_stripe
### ###
# Standard user views: # Payments and Payment Methods.
class BalanceViewSet(viewsets.ViewSet):
# here we return a number
# number = sum(payments) - sum(bills)
#bills = Bill.objects.filter(owner=self.request.user)
#payments = Payment.objects.filter(owner=self.request.user)
# sum_paid = sum([ amount for amount payments..,. ]) # you get the picture
# sum_to_be_paid = sum([ amount for amount bills..,. ]) # you get the picture
pass
class BillViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = BillSerializer
permission_classes = [permissions.IsAuthenticated]
def get_queryset(self):
return Bill.objects.filter(owner=self.request.user)
class PaymentViewSet(viewsets.ReadOnlyModelViewSet): class PaymentViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = PaymentSerializer serializer_class = PaymentSerializer
@ -48,19 +32,19 @@ class OrderViewSet(viewsets.ReadOnlyModelViewSet):
def get_queryset(self): def get_queryset(self):
return Order.objects.filter(owner=self.request.user) return Order.objects.filter(owner=self.request.user)
class PaymentMethodViewSet(viewsets.ModelViewSet): class PaymentMethodViewSet(viewsets.ModelViewSet):
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
def get_serializer_class(self): def get_serializer_class(self):
if self.action == 'create': if self.action == 'create':
return CreatePaymentMethodSerializer return CreatePaymentMethodSerializer
elif self.action == 'update':
return UpdatePaymentMethodSerializer
elif self.action == 'charge': elif self.action == 'charge':
return ChargePaymentMethodSerializer return ChargePaymentMethodSerializer
else: else:
return PaymentMethodSerializer return PaymentMethodSerializer
def get_queryset(self): def get_queryset(self):
return PaymentMethod.objects.filter(owner=self.request.user) return PaymentMethod.objects.filter(owner=self.request.user)
@ -70,28 +54,38 @@ class PaymentMethodViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
# Retrieve Stripe customer ID for user. # Set newly created method as primary if no other method is.
customer_id = uncloud_stripe.get_customer_id_for(request.user) if PaymentMethod.get_primary_for(request.user) == None:
if customer_id == None: serializer.validated_data['primary'] = True
return Response(
if serializer.validated_data['source'] == "stripe":
# Retrieve Stripe customer ID for user.
customer_id = uncloud_stripe.get_customer_id_for(request.user)
if customer_id == None:
return Response(
{'error': 'Could not resolve customer stripe ID.'}, {'error': 'Could not resolve customer stripe ID.'},
status=status.HTTP_500_INTERNAL_SERVER_ERROR) status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Register card under stripe customer. try:
credit_card = uncloud_stripe.CreditCard(**serializer.validated_data.pop('credit_card')) setup_intent = uncloud_stripe.create_setup_intent(customer_id)
card_request = uncloud_stripe.create_card(customer_id, credit_card) except Exception as e:
if card_request['error']: return Response({'error': str(e)},
return Response({'stripe_error': card_request['error']}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) status=status.HTTP_500_INTERNAL_SERVER_ERROR)
card_id = card_request['response_object']['id']
# Save payment method locally. payment_method = PaymentMethod.objects.create(
serializer.validated_data['stripe_card_id'] = card_request['response_object']['id'] owner=request.user,
payment_method = PaymentMethod.objects.create(owner=request.user, **serializer.validated_data) stripe_setup_intent_id=setup_intent.id,
**serializer.validated_data)
# We do not want to return the credit card details sent with the POST # TODO: find a way to use reverse properly:
# request. # https://www.django-rest-framework.org/api-guide/reverse/
output_serializer = PaymentMethodSerializer(payment_method) path = "payment-method/{}/register-stripe-cc".format(
return Response(output_serializer.data) payment_method.uuid)
stripe_registration_url = reverse('api-root', request=request) + path
return Response({'please_visit': stripe_registration_url})
else:
serializer.save(owner=request.user, **serializer.validated_data)
return Response(serializer.data)
@action(detail=True, methods=['post']) @action(detail=True, methods=['post'])
def charge(self, request, pk=None): def charge(self, request, pk=None):
@ -106,8 +100,96 @@ class PaymentMethodViewSet(viewsets.ModelViewSet):
except Exception as e: except Exception as e:
return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=True, methods=['get'], url_path='register-stripe-cc', renderer_classes=[TemplateHTMLRenderer])
def register_stripe_cc(self, request, pk=None):
payment_method = self.get_object()
if payment_method.source != 'stripe':
return Response(
{'error': 'This is not a Stripe-based payment method.'},
template_name='error.html.j2')
if payment_method.active:
return Response(
{'error': 'This payment method is already active'},
template_name='error.html.j2')
try:
setup_intent = uncloud_stripe.get_setup_intent(
payment_method.stripe_setup_intent_id)
except Exception as e:
return Response(
{'error': str(e)},
template_name='error.html.j2')
# TODO: find a way to use reverse properly:
# https://www.django-rest-framework.org/api-guide/reverse/
callback_path= "payment-method/{}/activate-stripe-cc/".format(
payment_method.uuid)
callback = reverse('api-root', request=request) + callback_path
# Render stripe card registration form.
template_args = {
'client_secret': setup_intent.client_secret,
'stripe_pk': uncloud_stripe.public_api_key,
'callback': callback
}
return Response(template_args, template_name='stripe-payment.html.j2')
@action(detail=True, methods=['post'], url_path='activate-stripe-cc')
def activate_stripe_cc(self, request, pk=None):
payment_method = self.get_object()
try:
setup_intent = uncloud_stripe.get_setup_intent(
payment_method.stripe_setup_intent_id)
except Exception as e:
return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Card had been registered, fetching payment method.
print(setup_intent)
if setup_intent.payment_method:
payment_method.stripe_payment_method_id = setup_intent.payment_method
payment_method.save()
return Response({
'uuid': payment_method.uuid,
'activated': payment_method.active})
else:
error = 'Could not fetch payment method from stripe. Please try again.'
return Response({'error': error})
@action(detail=True, methods=['post'], url_path='set-as-primary')
def set_as_primary(self, request, pk=None):
payment_method = self.get_object()
payment_method.set_as_primary_for(request.user)
serializer = self.get_serializer(payment_method)
return Response(serializer.data)
### ###
# Admin views. # Bills and Orders.
class BillViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = BillSerializer
permission_classes = [permissions.IsAuthenticated]
def get_queryset(self):
return Bill.objects.filter(owner=self.request.user)
def unpaid(self, request):
return Bill.objects.filter(owner=self.request.user, paid=False)
class OrderViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = OrderSerializer
permission_classes = [permissions.IsAuthenticated]
def get_queryset(self):
return Order.objects.filter(owner=self.request.user)
###
# Old admin stuff.
class AdminPaymentViewSet(viewsets.ModelViewSet): class AdminPaymentViewSet(viewsets.ModelViewSet):
serializer_class = PaymentSerializer serializer_class = PaymentSerializer

View file

@ -76,6 +76,8 @@ class VMProduct(Product):
return self.cores * 3 + self.ram_in_gb * 4 return self.cores * 3 + self.ram_in_gb * 4
elif recurring_period == RecurringPeriod.PER_HOUR: elif recurring_period == RecurringPeriod.PER_HOUR:
return self.cores * 4.0/(30 * 24) + self.ram_in_gb * 4.5/(30* 24) return self.cores * 4.0/(30 * 24) + self.ram_in_gb * 4.5/(30* 24)
elif recurring_period == RecurringPeriod.PER_YEAR:
return (self.cores * 2.5 + self.ram_in_gb * 3.5) * 12
else: else:
raise Exception('Invalid recurring period for VM Product pricing.') raise Exception('Invalid recurring period for VM Product pricing.')
@ -92,7 +94,8 @@ class VMProduct(Product):
@staticmethod @staticmethod
def allowed_recurring_periods(): def allowed_recurring_periods():
return list(filter( return list(filter(
lambda pair: pair[0] in [RecurringPeriod.PER_MONTH, RecurringPeriod.PER_HOUR], lambda pair: pair[0] in [RecurringPeriod.PER_YEAR,
RecurringPeriod.PER_MONTH, RecurringPeriod.PER_HOUR],
RecurringPeriod.choices)) RecurringPeriod.choices))
class VMWithOSProduct(VMProduct): class VMWithOSProduct(VMProduct):

View file

@ -8,7 +8,7 @@ from django.utils import timezone
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from uncloud_vm.models import VMDiskImageProduct, VMDiskProduct, VMProduct, VMHost from uncloud_vm.models import VMDiskImageProduct, VMDiskProduct, VMProduct, VMHost
from uncloud_pay.models import Order from uncloud_pay.models import Order, RecurringPeriod
User = get_user_model() User = get_user_model()
cal = parsedatetime.Calendar() cal = parsedatetime.Calendar()
@ -52,31 +52,32 @@ class VMTestCase(TestCase):
creation_date=datetime.datetime.now(tz=timezone.utc), creation_date=datetime.datetime.now(tz=timezone.utc),
starting_date=datetime.datetime.now(tz=timezone.utc), starting_date=datetime.datetime.now(tz=timezone.utc),
ending_date=datetime.datetime(*one_month_later[:6], tzinfo=timezone.utc), ending_date=datetime.datetime(*one_month_later[:6], tzinfo=timezone.utc),
recurring_price=4.0, one_time_price=5.0, recurring_period='per_month' recurring_period=RecurringPeriod.PER_MONTH
) )
) )
def test_disk_product(self): # TODO: the logic tested by this test is not implemented yet.
"""Ensures that a VMDiskProduct can only be created from a VMDiskImageProduct # def test_disk_product(self):
that is in status 'active'""" # """Ensures that a VMDiskProduct can only be created from a VMDiskImageProduct
# that is in status 'active'"""
vm = self.create_sample_vm(owner=self.user) #
# vm = self.create_sample_vm(owner=self.user)
pending_disk_image = VMDiskImageProduct.objects.create( #
owner=self.user, name='pending_disk_image', is_os_image=True, is_public=True, size_in_gb=10, # pending_disk_image = VMDiskImageProduct.objects.create(
status='pending' # owner=self.user, name='pending_disk_image', is_os_image=True, is_public=True, size_in_gb=10,
) # status='pending'
try: # )
vm_disk_product = VMDiskProduct.objects.create( # try:
owner=self.user, vm=vm, image=pending_disk_image, size_in_gb=10 # vm_disk_product = VMDiskProduct.objects.create(
) # owner=self.user, vm=vm, image=pending_disk_image, size_in_gb=10
except ValidationError: # )
vm_disk_product = None # except ValidationError:
# vm_disk_product = None
self.assertIsNone( #
vm_disk_product, # self.assertIsNone(
msg='VMDiskProduct created with disk image whose status is not active.' # vm_disk_product,
) # msg='VMDiskProduct created with disk image whose status is not active.'
# )
def test_vm_disk_product_creation(self): def test_vm_disk_product_creation(self):
"""Ensure that a user can only create a VMDiskProduct for an existing VM""" """Ensure that a user can only create a VMDiskProduct for an existing VM"""
@ -94,19 +95,20 @@ class VMTestCase(TestCase):
owner=self.user, vm=vm, image=disk_image, size_in_gb=10 owner=self.user, vm=vm, image=disk_image, size_in_gb=10
) )
def test_vm_disk_product_creation_for_someone_else(self): # TODO: the logic tested by this test is not implemented yet.
"""Ensure that a user can only create a VMDiskProduct for his/her own VM""" # def test_vm_disk_product_creation_for_someone_else(self):
# """Ensure that a user can only create a VMDiskProduct for his/her own VM"""
# Create a VM which is ownership of self.user2 #
someone_else_vm = self.create_sample_vm(owner=self.user2) # # Create a VM which is ownership of self.user2
# someone_else_vm = self.create_sample_vm(owner=self.user2)
# 'self.user' would try to create a VMDiskProduct for 'user2's VM #
with self.assertRaises(ValidationError, msg='User created a VMDiskProduct for someone else VM.'): # # 'self.user' would try to create a VMDiskProduct for 'user2's VM
vm_disk_product = VMDiskProduct.objects.create( # with self.assertRaises(ValidationError, msg='User created a VMDiskProduct for someone else VM.'):
owner=self.user, vm=someone_else_vm, # vm_disk_product = VMDiskProduct.objects.create(
size_in_gb=10, # owner=self.user, vm=someone_else_vm,
image=VMDiskImageProduct.objects.create( # size_in_gb=10,
owner=self.user, name='disk_image', is_os_image=True, is_public=True, size_in_gb=10, # image=VMDiskImageProduct.objects.create(
status='active' # owner=self.user, name='disk_image', is_os_image=True, is_public=True, size_in_gb=10,
) # status='active'
) # )
# )

View file

@ -1,5 +1,6 @@
from django.db import transaction from django.db import transaction
from django.shortcuts import render from django.shortcuts import render
from django.utils import timezone
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
@ -11,10 +12,7 @@ from rest_framework.exceptions import ValidationError
from .models import VMHost, VMProduct, VMSnapshotProduct, VMDiskProduct, VMDiskImageProduct, VMCluster from .models import VMHost, VMProduct, VMSnapshotProduct, VMDiskProduct, VMDiskImageProduct, VMCluster
from uncloud_pay.models import Order from uncloud_pay.models import Order
from .serializers import (VMHostSerializer, VMProductSerializer, from .serializers import *
VMSnapshotProductSerializer, VMDiskImageProductSerializer,
VMDiskProductSerializer, DCLVMProductSerializer,
VMClusterSerializer)
from uncloud_pay.helpers import ProductViewSet from uncloud_pay.helpers import ProductViewSet
@ -121,7 +119,8 @@ class VMProductViewSet(ProductViewSet):
# Create base order. # Create base order.
order = Order.objects.create( order = Order.objects.create(
recurring_period=order_recurring_period, recurring_period=order_recurring_period,
owner=request.user owner=request.user,
starting_date=timezone.now()
) )
order.save() order.save()

View file

@ -1,4 +1,4 @@
# Generated by Django 3.0.3 on 2020-03-17 11:45 # Generated by Django 3.0.3 on 2020-03-09 07:57
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
@ -12,8 +12,8 @@ class Migration(migrations.Migration):
dependencies = [ dependencies = [
('uncloud_vm', '0003_remove_vmhost_vms'), ('uncloud_vm', '0003_remove_vmhost_vms'),
('uncloud_pay', '0002_auto_20200305_1524'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL), migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('uncloud_pay', '0001_initial'),
] ]
operations = [ operations = [