Merge remote-tracking branch 'origin/fnux-hacks'

This commit is contained in:
fnux 2020-03-04 12:10:22 +01:00
commit 371c5ccf00
14 changed files with 359 additions and 220 deletions

View file

@ -2,32 +2,9 @@ from functools import reduce
from datetime import datetime from datetime import datetime
from rest_framework import mixins from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet from rest_framework.viewsets import GenericViewSet
from django.db.models import Q
from .models import Bill, Payment, PaymentMethod, Order
from django.utils import timezone from django.utils import timezone
from django.core.exceptions import ObjectDoesNotExist
from calendar import monthrange from calendar import monthrange
def get_balance_for(user):
bills = reduce(
lambda acc, entry: acc + entry.total,
Bill.objects.filter(owner=user),
0)
payments = reduce(
lambda acc, entry: acc + entry.amount,
Payment.objects.filter(owner=user),
0)
return payments - bills
def get_payment_method_for(user):
methods = PaymentMethod.objects.filter(owner=user)
for method in methods:
# Do we want to do something with non-primary method?
if method.primary:
return method
return None
def beginning_of_month(year, month): def beginning_of_month(year, month):
tz = timezone.get_current_timezone() tz = timezone.get_current_timezone()
return datetime(year=year, month=month, day=1, tzinfo=tz) return datetime(year=year, month=month, day=1, tzinfo=tz)
@ -38,53 +15,6 @@ def end_of_month(year, month):
return datetime(year=year, month=month, day=days, return datetime(year=year, month=month, day=days,
hour=23, minute=59, second=59, tzinfo=tz) hour=23, minute=59, second=59, tzinfo=tz)
def generate_bills_for(year, month, user, allowed_delay):
# /!\ We exclusively work on the specified year and month.
# Default values for next bill (if any). Only saved at the end of
# this method, if relevant.
next_bill = Bill(owner=user,
starting_date=beginning_of_month(year, month),
ending_date=end_of_month(year, month),
creation_date=timezone.now(),
due_date=timezone.now() + allowed_delay)
# Select all orders active on the request period.
orders = Order.objects.filter(
Q(ending_date__gt=next_bill.starting_date) | Q(ending_date__isnull=True),
owner=user)
# Check if there is already a bill covering the order and period pair:
# * Get latest bill by ending_date: previous_bill.ending_date
# * If previous_bill.ending_date is before next_bill.ending_date, a new
# bill has to be generated.
unpaid_orders = []
for order in orders:
try:
previous_bill = order.bill.latest('ending_date')
except ObjectDoesNotExist:
previous_bill = None
if previous_bill == None or previous_bill.ending_date < next_bill.ending_date:
unpaid_orders.append(order)
# Commit next_bill if it there are 'unpaid' orders.
if len(unpaid_orders) > 0:
next_bill.save()
# It is not possible to register many-to-many relationship before
# the two end-objects are saved in database.
for order in unpaid_orders:
order.bill.add(next_bill)
# TODO: use logger.
print("Generated bill {} (amount: {}) for user {}."
.format(next_bill.uuid, next_bill.total, user))
return next_bill
# Return None if no bill was created.
class ProductViewSet(mixins.CreateModelMixin, class ProductViewSet(mixins.CreateModelMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.ListModelMixin, mixins.ListModelMixin,

View file

@ -1,7 +1,6 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from uncloud_auth.models import User from uncloud_auth.models import User
from uncloud_pay.models import Order, Bill from uncloud_pay.models import Order, Bill, PaymentMethod, get_balance_for
from uncloud_pay.helpers import get_balance_for, get_payment_method_for
from datetime import timedelta from datetime import timedelta
from django.utils import timezone from django.utils import timezone
@ -19,7 +18,7 @@ class Command(BaseCommand):
balance = get_balance_for(user) balance = get_balance_for(user)
if balance < 0: if balance < 0:
print("User {} has negative balance ({}), charging.".format(user.username, balance)) print("User {} has negative balance ({}), charging.".format(user.username, balance))
payment_method = get_payment_method_for(user) payment_method = PaymentMethod.get_primary_for(user)
if payment_method != None: if payment_method != None:
amount_to_be_charged = abs(balance) amount_to_be_charged = abs(balance)
charge_ok = payment_method.charge(amount_to_be_charged) charge_ok = payment_method.charge(amount_to_be_charged)

View file

@ -7,7 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist
from datetime import timedelta, date from datetime import timedelta, date
from django.utils import timezone from django.utils import timezone
from uncloud_pay.helpers import generate_bills_for from uncloud_pay.models import Bill
BILL_PAYMENT_DELAY=timedelta(days=10) BILL_PAYMENT_DELAY=timedelta(days=10)
@ -28,7 +28,7 @@ class Command(BaseCommand):
for user in users: for user in users:
now = timezone.now() now = timezone.now()
generate_bills_for( Bill.generate_for(
year=now.year, year=now.year,
month=now.month, month=now.month,
user=user, user=user,

View file

@ -1,13 +1,12 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from uncloud_auth.models import User from uncloud_auth.models import User
from uncloud_pay.models import Order, Bill from uncloud_pay.models import Bill
from uncloud_pay.helpers import get_balance_for, get_payment_method_for
from datetime import timedelta from datetime import timedelta
from django.utils import timezone from django.utils import timezone
class Command(BaseCommand): class Command(BaseCommand):
help = 'Generate bills and charge customers if necessary.' help = 'Take action on overdue bills.'
def add_arguments(self, parser): def add_arguments(self, parser):
pass pass
@ -16,28 +15,9 @@ class Command(BaseCommand):
users = User.objects.all() users = User.objects.all()
print("Processing {} users.".format(users.count())) print("Processing {} users.".format(users.count()))
for user in users: for user in users:
balance = get_balance_for(user) for bill in Bill.get_overdue_for(user):
if balance < 0: print("/!\ Overdue bill for {}, {} with amount {}"
print("User {} has negative balance ({}), checking for overdue bills." .format(user.username, bill.uuid, bill.amount))
.format(user.username, balance)) # TODO: take action?
# Get bills DESCENDING by creation date (= latest at top).
bills = Bill.objects.filter(
owner=user,
due_date__lt=timezone.now()
).order_by('-creation_date')
overdue_balance = abs(balance)
overdue_bills = []
for bill in bills:
if overdue_balance < 0:
break # XXX: I'm (fnux) not fond of breaks!
overdue_balance -= bill.amount
overdue_bills.append(bill)
for bill in overdue_bills:
print("/!\ Overdue bill for {}, {} with amount {}"
.format(user.username, bill.uuid, bill.amount))
# TODO: take action?
print("=> Done.") print("=> Done.")

View file

@ -0,0 +1,18 @@
# Generated by Django 3.0.3 on 2020-03-03 10:27
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('uncloud_pay', '0001_initial'),
]
operations = [
migrations.RenameField(
model_name='orderrecord',
old_name='setup_fee',
new_name='one_time_price',
),
]

View file

@ -0,0 +1,24 @@
# Generated by Django 3.0.3 on 2020-03-03 13:56
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('uncloud_pay', '0014_auto_20200303_1027'),
]
operations = [
migrations.CreateModel(
name='StripeCustomer',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('stripe_id', models.CharField(max_length=32)),
('owner', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
),
]

View file

@ -0,0 +1,25 @@
# Generated by Django 3.0.3 on 2020-03-03 15:52
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('uncloud_pay', '0015_stripecustomer'),
]
operations = [
migrations.RemoveField(
model_name='stripecustomer',
name='id',
),
migrations.AlterField(
model_name='stripecustomer',
name='owner',
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to=settings.AUTH_USER_MODEL),
),
]

View file

@ -1,19 +1,31 @@
from django.db import models from django.db import models
from functools import reduce from django.db.models import Q
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.validators import MinValueValidator from django.core.validators import MinValueValidator
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils import timezone from django.utils import timezone
from django.dispatch import receiver
from django.core.exceptions import ObjectDoesNotExist
import django.db.models.signals as signals
import uuid
from functools import reduce
from math import ceil from math import ceil
from datetime import timedelta from datetime import timedelta
from calendar import monthrange from calendar import monthrange
import uuid import uncloud_pay.stripe
from uncloud_pay.helpers import beginning_of_month, end_of_month
from decimal import Decimal
# Define DecimalField properties, used to represent amounts of money. # Define DecimalField properties, used to represent amounts of money.
AMOUNT_MAX_DIGITS=10 AMOUNT_MAX_DIGITS=10
AMOUNT_DECIMALS=2 AMOUNT_DECIMALS=2
# Used to generate bill due dates.
BILL_PAYMENT_DELAY=timedelta(days=10)
# See https://docs.djangoproject.com/en/dev/ref/models/fields/#field-choices-enum-types # See https://docs.djangoproject.com/en/dev/ref/models/fields/#field-choices-enum-types
class RecurringPeriod(models.TextChoices): class RecurringPeriod(models.TextChoices):
ONE_TIME = 'ONCE', _('Onetime') ONE_TIME = 'ONCE', _('Onetime')
@ -24,6 +36,34 @@ class RecurringPeriod(models.TextChoices):
PER_HOUR = 'HOUR', _('Per Hour') PER_HOUR = 'HOUR', _('Per Hour')
PER_SECOND = 'SECOND', _('Per Second') PER_SECOND = 'SECOND', _('Per Second')
# See https://docs.djangoproject.com/en/dev/ref/models/fields/#field-choices-enum-types
class ProductStatus(models.TextChoices):
PENDING = 'PENDING', _('Pending')
AWAITING_PAYMENT = 'AWAITING_PAYMENT', _('Awaiting payment')
BEING_CREATED = 'BEING_CREATED', _('Being created')
ACTIVE = 'ACTIVE', _('Active')
DELETED = 'DELETED', _('Deleted')
###
# Users.
def get_balance_for(user):
bills = reduce(
lambda acc, entry: acc + entry.total,
Bill.objects.filter(owner=user),
0)
payments = reduce(
lambda acc, entry: acc + entry.amount,
Payment.objects.filter(owner=user),
0)
return payments - bills
class StripeCustomer(models.Model):
owner = models.OneToOneField( get_user_model(),
primary_key=True,
on_delete=models.CASCADE)
stripe_id = models.CharField(max_length=32)
### ###
# Payments and Payment Methods. # Payments and Payment Methods.
@ -50,6 +90,20 @@ class Payment(models.Model):
default='unknown') default='unknown')
timestamp = models.DateTimeField(editable=False, auto_now_add=True) timestamp = models.DateTimeField(editable=False, auto_now_add=True)
# WIP prepaid and service activation logic by fnux.
## We override save() in order to active products awaiting payment.
#def save(self, *args, **kwargs):
# # TODO: only run activation logic on creation, not on update.
# unpaid_bills_before_payment = Bill.get_unpaid_for(self.owner)
# super(Payment, self).save(*args, **kwargs) # Save payment in DB.
# unpaid_bills_after_payment = Bill.get_unpaid_for(self.owner)
# newly_paid_bills = list(
# set(unpaid_bills_before_payment) - set(unpaid_bills_after_payment))
# for bill in newly_paid_bills:
# bill.activate_orders()
class PaymentMethod(models.Model): class PaymentMethod(models.Model):
uuid = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) uuid = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
owner = models.ForeignKey(get_user_model(), owner = models.ForeignKey(get_user_model(),
@ -67,24 +121,50 @@ class PaymentMethod(models.Model):
# Only used for "Stripe" source # Only used for "Stripe" source
stripe_card_id = models.CharField(max_length=32, blank=True, null=True) stripe_card_id = models.CharField(max_length=32, blank=True, null=True)
@property
def stripe_card_last4(self):
if self.source == 'stripe':
card_request = uncloud_pay.stripe.get_card(
StripeCustomer.objects.get(owner=self.owner).stripe_id,
self.stripe_card_id)
if card_request['error'] == None:
return card_request['response_object']['last4']
else:
return None
else:
return None
def charge(self, amount): def charge(self, amount):
if amount > 0: # Make sure we don't charge negative amount by errors... if amount > 0: # Make sure we don't charge negative amount by errors...
if self.source == 'stripe': if self.source == 'stripe':
# TODO: wire to stripe, see meooow-payv1/strip_utils.py stripe_customer = StripeCustomer.objects.get(owner=self.owner).stripe_id
payment = Payment(owner=self.owner, source=self.source, amount=amount) charge_request = uncloud_pay.stripe.charge_customer(amount, stripe_customer, self.stripe_card_id)
payment.save() # TODO: Check return status if charge_request['error'] == None:
payment = Payment(owner=self.owner, source=self.source, amount=amount)
payment.save() # TODO: Check return status
return True return payment
else:
raise Exception('Stripe error: {}'.format(charge_request['error']))
else: else:
# We do not handle that source yet. raise Exception('This payment method is unsupported/cannot be charged.')
return False
else: else:
return False raise Exception('Cannot charge negative amount.')
def get_primary_for(user):
methods = PaymentMethod.objects.filter(owner=user)
for method in methods:
# Do we want to do something with non-primary method?
if method.primary:
return method
return None
class Meta: class Meta:
unique_together = [['owner', 'primary']] unique_together = [['owner', 'primary']]
### ###
# Bills & Payments. # Bills & Payments.
@ -113,13 +193,91 @@ class Bill(models.Model):
@property @property
def total(self): def total(self):
return reduce(lambda acc, record: acc + record.amount(), self.records, 0) return reduce(lambda acc, record: acc + record.amount, self.records, 0)
@property @property
def final(self): def final(self):
# A bill is final when its ending date is passed. # A bill is final when its ending date is passed.
return self.ending_date < timezone.now() return self.ending_date < timezone.now()
@staticmethod
def generate_for(year, month, user):
# /!\ We exclusively work on the specified year and month.
# Default values for next bill (if any). Only saved at the end of
# this method, if relevant.
next_bill = Bill(owner=user,
starting_date=beginning_of_month(year, month),
ending_date=end_of_month(year, month),
creation_date=timezone.now(),
due_date=timezone.now() + BILL_PAYMENT_DELAY)
# Select all orders active on the request period.
orders = Order.objects.filter(
Q(ending_date__gt=next_bill.starting_date) | Q(ending_date__isnull=True),
owner=user)
# Check if there is already a bill covering the order and period pair:
# * Get latest bill by ending_date: previous_bill.ending_date
# * If previous_bill.ending_date is before next_bill.ending_date, a new
# bill has to be generated.
unpaid_orders = []
for order in orders:
try:
previous_bill = order.bill.latest('ending_date')
except ObjectDoesNotExist:
previous_bill = None
if previous_bill == None or previous_bill.ending_date < next_bill.ending_date:
unpaid_orders.append(order)
# Commit next_bill if it there are 'unpaid' orders.
if len(unpaid_orders) > 0:
next_bill.save()
# It is not possible to register many-to-many relationship before
# the two end-objects are saved in database.
for order in unpaid_orders:
order.bill.add(next_bill)
# TODO: use logger.
print("Generated bill {} (amount: {}) for user {}."
.format(next_bill.uuid, next_bill.total, user))
return next_bill
# Return None if no bill was created.
return None
@staticmethod
def get_unpaid_for(user):
balance = get_balance_for(user)
unpaid_bills = []
# No unpaid bill if balance is positive.
if balance >= 0:
return []
else:
bills = Bill.objects.filter(
owner=user,
due_date__lt=timezone.now()
).order_by('-creation_date')
# Amount to be paid by the customer.
unpaid_balance = abs(balance)
for bill in bills:
if unpaid_balance < 0:
break
unpaid_balance -= bill.amount
unpaid_bills.append(bill)
return unpaid_bills
@staticmethod
def get_overdue_for(user):
unpaid_bills = Bill.get_unpaid_for(user)
return list(filter(lambda bill: bill.due_date > timezone.now(), unpaid_bills))
class BillRecord(): class BillRecord():
""" """
Entry of a bill, dynamically generated from order records. Entry of a bill, dynamically generated from order records.
@ -128,12 +286,17 @@ class BillRecord():
def __init__(self, bill, order_record): def __init__(self, bill, order_record):
self.bill = bill self.bill = bill
self.order = order_record.order self.order = order_record.order
self.setup_fee = order_record.setup_fee
self.recurring_price = order_record.recurring_price self.recurring_price = order_record.recurring_price
self.recurring_period = order_record.recurring_period self.recurring_period = order_record.recurring_period
self.description = order_record.description self.description = order_record.description
def amount(self): if self.order.starting_date > self.bill.starting_date:
self.one_time_price = order_record.one_time_price
else:
self.one_time_price = 0
@property
def recurring_count(self):
# Compute billing delta. # Compute billing delta.
billed_until = self.bill.ending_date billed_until = self.bill.ending_date
if self.order.ending_date != None and self.order.ending_date < self.order.ending_date: if self.order.ending_date != None and self.order.ending_date < self.order.ending_date:
@ -166,39 +329,31 @@ class BillRecord():
(_, days_in_month) = monthrange( (_, days_in_month) = monthrange(
self.bill.starting_date.year, self.bill.starting_date.year,
self.bill.starting_date.month) self.bill.starting_date.month)
adjusted_recurring_price = self.recurring_price / days_in_month return Decimal(days / days_in_month)
recurring_price = adjusted_recurring_price * days
return self.recurring_price # TODO
elif self.recurring_period == RecurringPeriod.PER_DAY: elif self.recurring_period == RecurringPeriod.PER_DAY:
days = ceil(billed_delta / timedelta(days=1)) days = ceil(billed_delta / timedelta(days=1))
return self.recurring_price * days return Decimal(days)
elif self.recurring_period == RecurringPeriod.PER_HOUR: elif self.recurring_period == RecurringPeriod.PER_HOUR:
hours = ceil(billed_delta / timedelta(hours=1)) hours = ceil(billed_delta / timedelta(hours=1))
return self.recurring_price * hours return Decimal(hours)
elif self.recurring_period == RecurringPeriod.PER_SECOND: elif self.recurring_period == RecurringPeriod.PER_SECOND:
seconds = ceil(billed_delta / timedelta(seconds=1)) seconds = ceil(billed_delta / timedelta(seconds=1))
return self.recurring_price * seconds return Decimal(seconds)
elif self.recurring_period == RecurringPeriod.ONE_TIME:
return Decimal(0)
else: else:
raise Exception('Unsupported recurring period: {}.'. raise Exception('Unsupported recurring period: {}.'.
format(record.recurring_period)) format(record.recurring_period))
@property
def amount(self):
return self.recurring_price * self.recurring_count + self.one_time_price
### ###
# Orders. # Orders.
# /!\ BIG FAT WARNING /!\ #
#
# Order are assumed IMMUTABLE and used as SOURCE OF TRUST for generating # Order are assumed IMMUTABLE and used as SOURCE OF TRUST for generating
# bills. Do **NOT** mutate then! # bills. Do **NOT** mutate then!
#
# Why? We need to store the state somewhere since product are mutable (e.g.
# adding RAM to VM, changing price of 1GB of RAM, ...). An alternative could
# have been to only store the state in bills but would have been more
# confusing: the order is a 'contract' with the customer, were both parts
# agree on deal => That's what we want to keep archived.
#
# /!\ BIG FAT WARNING /!\ #
class Order(models.Model): class Order(models.Model):
uuid = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) uuid = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
owner = models.ForeignKey(get_user_model(), owner = models.ForeignKey(get_user_model(),
@ -224,22 +379,31 @@ class Order(models.Model):
return OrderRecord.objects.filter(order=self) return OrderRecord.objects.filter(order=self)
@property @property
def setup_fee(self): def one_time_price(self):
return reduce(lambda acc, record: acc + record.setup_fee, self.records, 0) return reduce(lambda acc, record: acc + record.one_time_price, self.records, 0)
@property @property
def recurring_price(self): def recurring_price(self):
return reduce(lambda acc, record: acc + record.recurring_price, self.records, 0) return reduce(lambda acc, record: acc + record.recurring_price, self.records, 0)
def add_record(self, setup_fee, recurring_price, description): def add_record(self, one_time_price, recurring_price, description):
OrderRecord.objects.create(order=self, OrderRecord.objects.create(order=self,
setup_fee=setup_fee, one_time_price=one_time_price,
recurring_price=recurring_price, recurring_price=recurring_price,
description=description) description=description)
class OrderRecord(models.Model): class OrderRecord(models.Model):
"""
Order records store billing informations for products: the actual product
might be mutated and/or moved to another order but we do not want to loose
the details of old orders.
Used as source of trust to dynamically generate bill entries.
"""
order = models.ForeignKey(Order, on_delete=models.CASCADE) order = models.ForeignKey(Order, on_delete=models.CASCADE)
setup_fee = models.DecimalField(default=0.0, one_time_price = models.DecimalField(default=0.0,
max_digits=AMOUNT_MAX_DIGITS, max_digits=AMOUNT_MAX_DIGITS,
decimal_places=AMOUNT_DECIMALS, decimal_places=AMOUNT_DECIMALS,
validators=[MinValueValidator(0)]) validators=[MinValueValidator(0)])
@ -276,15 +440,9 @@ class Product(models.Model):
description = "" description = ""
status = models.CharField(max_length=256, status = models.CharField(max_length=32,
choices = ( choices=ProductStatus.choices,
('pending', 'Pending'), default=ProductStatus.PENDING)
('being_created', 'Being created'),
('active', 'Active'),
('deleted', 'Deleted')
),
default='pending'
)
order = models.ForeignKey(Order, order = models.ForeignKey(Order,
on_delete=models.CASCADE, on_delete=models.CASCADE,
@ -296,7 +454,7 @@ class Product(models.Model):
pass # To be implemented in child. pass # To be implemented in child.
@property @property
def setup_fee(self): def one_time_price(self):
return 0 return 0
@property @property

View file

@ -1,13 +1,6 @@
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework import serializers from rest_framework import serializers
from .models import * from .models import *
from .helpers import get_balance_for
from functools import reduce
from uncloud_vm.serializers import VMProductSerializer
from uncloud_vm.models import VMProduct
import uncloud_pay.stripe as stripe
### ###
# Users. # Users.
@ -19,8 +12,6 @@ class UserSerializer(serializers.ModelSerializer):
# Display current 'balance' # Display current 'balance'
balance = serializers.SerializerMethodField('get_balance') balance = serializers.SerializerMethodField('get_balance')
def __sum_balance(self, entries):
return reduce(lambda acc, entry: acc + entry.amount, entries, 0)
def get_balance(self, user): def get_balance(self, user):
return get_balance_for(user) return get_balance_for(user)
@ -34,9 +25,14 @@ class PaymentSerializer(serializers.ModelSerializer):
fields = ['owner', 'amount', 'source', 'timestamp'] fields = ['owner', 'amount', 'source', 'timestamp']
class PaymentMethodSerializer(serializers.ModelSerializer): class PaymentMethodSerializer(serializers.ModelSerializer):
stripe_card_last4 = serializers.IntegerField()
class Meta: class Meta:
model = PaymentMethod model = PaymentMethod
fields = ['source', 'description', 'primary'] fields = ['uuid', 'source', 'description', 'primary', 'stripe_card_last4']
class ChargePaymentMethodSerializer(serializers.Serializer):
amount = serializers.DecimalField(max_digits=10, decimal_places=2)
class CreditCardSerializer(serializers.Serializer): class CreditCardSerializer(serializers.Serializer):
number = serializers.IntegerField() number = serializers.IntegerField()
@ -51,41 +47,6 @@ class CreatePaymentMethodSerializer(serializers.ModelSerializer):
model = PaymentMethod model = PaymentMethod
fields = ['source', 'description', 'primary', 'credit_card'] fields = ['source', 'description', 'primary', 'credit_card']
def create(self, validated_data):
credit_card = stripe.CreditCard(**validated_data.pop('credit_card'))
user = self.context['request'].user
customer = stripe.create_customer(user.username, user.email)
# TODO check customer error
customer_id = customer['response_object']['id']
stripe_card = stripe.create_card(customer_id, credit_card)
# TODO: check credit card error
validated_data['stripe_card_id'] = stripe_card['response_object']['id']
class CreditCardSerializer(serializers.Serializer):
number = serializers.IntegerField()
exp_month = serializers.IntegerField()
exp_year = serializers.IntegerField()
cvc = serializers.IntegerField()
class CreatePaymentMethodSerializer(serializers.ModelSerializer):
credit_card = CreditCardSerializer()
class Meta:
model = PaymentMethod
fields = ['source', 'description', 'primary', 'credit_card']
def create(self, validated_data):
credit_card = stripe.CreditCard(**validated_data.pop('credit_card'))
user = self.context['request'].user
customer = stripe.create_customer(user.username, user.email)
# TODO check customer error
customer_id = customer['response_object']['id']
stripe_card = stripe.create_card(customer_id, credit_card)
# TODO: check credit card error
validated_data['stripe_card_id'] = stripe_card['response_object']['id']
payment_method = PaymentMethod.objects.create(**validated_data)
return payment_method
payment_method = PaymentMethod.objects.create(**validated_data)
return payment_method
### ###
# Bills # Bills
@ -96,6 +57,8 @@ class BillRecordSerializer(serializers.Serializer):
description = serializers.CharField() description = serializers.CharField()
recurring_period = serializers.CharField() recurring_period = serializers.CharField()
recurring_price = serializers.DecimalField(max_digits=10, decimal_places=2) recurring_price = serializers.DecimalField(max_digits=10, decimal_places=2)
recurring_count = serializers.DecimalField(max_digits=10, decimal_places=2)
one_time_price = serializers.DecimalField(max_digits=10, decimal_places=2)
amount = serializers.DecimalField(max_digits=10, decimal_places=2) amount = serializers.DecimalField(max_digits=10, decimal_places=2)
class BillSerializer(serializers.ModelSerializer): class BillSerializer(serializers.ModelSerializer):
@ -111,7 +74,7 @@ class BillSerializer(serializers.ModelSerializer):
class OrderRecordSerializer(serializers.ModelSerializer): class OrderRecordSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = OrderRecord model = OrderRecord
fields = ['setup_fee', 'recurring_price', 'description'] fields = ['one_time_price', 'recurring_price', 'description']
class OrderSerializer(serializers.ModelSerializer): class OrderSerializer(serializers.ModelSerializer):
@ -119,7 +82,4 @@ class OrderSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Order model = Order
fields = ['uuid', 'creation_date', 'starting_date', 'ending_date', fields = ['uuid', 'creation_date', 'starting_date', 'ending_date',
'bill', 'recurring_period', 'records', 'recurring_price', 'setup_fee'] 'bill', 'recurring_period', 'records', 'recurring_price', 'one_time_price']
class ProductSerializer(serializers.Serializer):
vms = VMProductSerializer(many=True, read_only=True)

View file

@ -2,6 +2,9 @@ import stripe
import stripe.error import stripe.error
import logging import logging
from django.core.exceptions import ObjectDoesNotExist
import uncloud_pay.models
import uncloud.secrets import uncloud.secrets
# Static stripe configuration used below. # Static stripe configuration used below.
@ -18,7 +21,7 @@ def handle_stripe_error(f):
'error': None 'error': None
} }
common_message = "Currently it's not possible to make payments." common_message = "Currently it is not possible to make payments."
try: try:
response_object = f(*args, **kwargs) response_object = f(*args, **kwargs)
response = { response = {
@ -79,11 +82,24 @@ class CreditCard():
# Actual Stripe logic. # Actual Stripe logic.
def get_customer_id_for(user):
try:
# .get() raise if there is no matching entry.
return uncloud_pay.models.StripeCustomer.objects.get(owner=user).stripe_id
except ObjectDoesNotExist:
# No entry yet - making a new one.
customer_request = create_customer(user.username, user.email)
if customer_request['error'] == None:
mapping = uncloud_pay.models.StripeCustomer.objects.create(
owner=user,
stripe_id=customer_request['response_object']['id']
)
return mapping.stripe_id
else:
return None
@handle_stripe_error @handle_stripe_error
def create_card(customer_id, credit_card): def create_card(customer_id, credit_card):
# Test settings
credit_card.number = "5555555555554444"
return stripe.Customer.create_source( return stripe.Customer.create_source(
customer_id, customer_id,
card={ card={
@ -95,14 +111,18 @@ def create_card(customer_id, credit_card):
@handle_stripe_error @handle_stripe_error
def get_card(customer_id, card_id): def get_card(customer_id, card_id):
return stripe.Card.retrieve_source(customer_id, card_id) return stripe.Customer.retrieve_source(customer_id, card_id)
@handle_stripe_error @handle_stripe_error
def charge_customer(amount, source): def charge_customer(amount, customer_id, card_id):
# Amount is in CHF but stripes requires smallest possible unit.
# See https://stripe.com/docs/api/charges/create
adjusted_amount = int(amount * 100)
return stripe.Charge.create( return stripe.Charge.create(
amount=amount, amount=adjusted_amount,
currenty=CURRENCY, currency=CURRENCY,
source=source) customer=customer_id,
source=card_id)
@handle_stripe_error @handle_stripe_error
def create_customer(name, email): def create_customer(name, email):

View file

@ -1,4 +1,5 @@
from django.shortcuts import render from django.shortcuts import render
from django.db import transaction
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework import viewsets, permissions, status from rest_framework import viewsets, permissions, status
from rest_framework.response import Response from rest_framework.response import Response
@ -9,6 +10,7 @@ import json
from .models import * from .models import *
from .serializers import * from .serializers import *
from datetime import datetime from datetime import datetime
import uncloud_pay.stripe as uncloud_stripe
### ###
# Standard user views: # Standard user views:
@ -62,6 +64,8 @@ class PaymentMethodViewSet(viewsets.ModelViewSet):
def get_serializer_class(self): def get_serializer_class(self):
if self.action == 'create': if self.action == 'create':
return CreatePaymentMethodSerializer return CreatePaymentMethodSerializer
elif self.action == 'charge':
return ChargePaymentMethodSerializer
else: else:
return PaymentMethodSerializer return PaymentMethodSerializer
@ -69,26 +73,47 @@ class PaymentMethodViewSet(viewsets.ModelViewSet):
def get_queryset(self): def get_queryset(self):
return PaymentMethod.objects.filter(owner=self.request.user) return PaymentMethod.objects.filter(owner=self.request.user)
# XXX: Handling of errors is far from great down there.
@transaction.atomic
def create(self, request): def create(self, request):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
serializer.save(owner=request.user)
headers = self.get_success_headers(serializer.data) # Retrieve Stripe customer ID for user.
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) customer_id = uncloud_stripe.get_customer_id_for(request.user)
if customer_id == None:
return Response(
{'error': 'Could not resolve customer stripe ID.'},
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Register card under stripe customer.
credit_card = uncloud_stripe.CreditCard(**serializer.validated_data.pop('credit_card'))
card_request = uncloud_stripe.create_card(customer_id, credit_card)
if card_request['error']:
return Response({'stripe_error': card_request['error']}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
card_id = card_request['response_object']['id']
# Save payment method locally.
serializer.validated_data['stripe_card_id'] = card_request['response_object']['id']
payment_method = PaymentMethod.objects.create(owner=request.user, **serializer.validated_data)
# We do not want to return the credit card details sent with the POST
# request.
output_serializer = PaymentMethodSerializer(payment_method)
return Response(output_serializer.data)
# TODO: find a way to customize serializer for actions.
# drf-action-serializer module seems to do that.
@action(detail=True, methods=['post']) @action(detail=True, methods=['post'])
def charge(self, request, pk=None): def charge(self, request, pk=None):
payment_method = self.get_object() payment_method = self.get_object()
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
amount = serializer.data['amount'] amount = serializer.validated_data['amount']
if payment_method.charge(amount): try:
return Response({'charged', amount}) payment = payment_method.charge(amount)
else: output_serializer = PaymentSerializer(payment)
return Response(status=status.HTTP_500_INTERNAL_ERROR) return Response(output_serializer.data)
except Exception as e:
return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
### ###
# Admin views. # Admin views.

View file

@ -106,7 +106,7 @@ class VMProductViewSet(ProductViewSet):
# Add Product record to order (VM is mutable, allows to keep history in order). # Add Product record to order (VM is mutable, allows to keep history in order).
# XXX: Move this to some kind of on_create hook in parent Product class? # XXX: Move this to some kind of on_create hook in parent Product class?
order.add_record(vm.setup_fee, order.add_record(vm.one_time_price,
vm.recurring_price(order.recurring_period), vm.description) vm.recurring_price(order.recurring_period), vm.description)
return Response(serializer.data) return Response(serializer.data)

View file

@ -28,5 +28,5 @@ class MatrixServiceProduct(Product):
RecurringPeriod.choices)) RecurringPeriod.choices))
@property @property
def setup_fee(self): def one_time_price(self):
return 30 return 30

View file

@ -41,7 +41,7 @@ class MatrixServiceProductViewSet(ProductViewSet):
# XXX: Move this to some kind of on_create hook in parent # XXX: Move this to some kind of on_create hook in parent
# Product class? # Product class?
order.add_record( order.add_record(
vm.setup_fee, vm.one_time_price,
vm.recurring_price(order.recurring_period), vm.recurring_price(order.recurring_period),
vm.description) vm.description)
@ -54,7 +54,7 @@ class MatrixServiceProductViewSet(ProductViewSet):
# XXX: Move this to some kind of on_create hook in parent # XXX: Move this to some kind of on_create hook in parent
# Product class? # Product class?
order.add_record( order.add_record(
service.setup_fee, service.one_time_price,
service.recurring_price(order.recurring_period), service.recurring_price(order.recurring_period),
service.description) service.description)