From da54a59ca25799a1bdb3182d7265533a43fcfe12 Mon Sep 17 00:00:00 2001 From: meow Date: Mon, 20 Jan 2020 12:30:12 +0500 Subject: [PATCH] initial commit --- .gitignore | 7 + README.md | 43 +++ config.py | 8 + etcd_wrapper.py | 75 +++++ helper.py | 62 +++++ ldap_manager.py | 64 +++++ products/ipv6-only-django.json | 27 ++ products/ipv6-only-vm.json | 33 +++ products/ipv6-only-vpn.json | 15 + products/membership.json | 15 + schemas.py | 134 +++++++++ stripe_utils.py | 490 +++++++++++++++++++++++++++++++++ ucloud_pay.py | 345 +++++++++++++++++++++++ 13 files changed, 1318 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 config.py create mode 100644 etcd_wrapper.py create mode 100644 helper.py create mode 100644 ldap_manager.py create mode 100644 products/ipv6-only-django.json create mode 100644 products/ipv6-only-vm.json create mode 100644 products/ipv6-only-vpn.json create mode 100644 products/membership.json create mode 100644 schemas.py create mode 100644 stripe_utils.py create mode 100644 ucloud_pay.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..77de841 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.idea/ +.vscode/ +__pycache__/ + +pay.conf +log.txt +test.py \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..1b50cf3 --- /dev/null +++ b/README.md @@ -0,0 +1,43 @@ +# uncloud-pay + +The pay module for the uncloud + +- uses [etcd3](https://coreos.com/blog/etcd3-a-new-etcd.html) for storage. +- uses [Stripe](https://stripe.com/docs/api) as the payment gateway. +- uses [ldap3](https://github.com/cannatag/ldap3) for ldap authentication. + +## Getting started + +**TODO** + +## Usage + +Currently handles very basic features, such as: + +#### 1. Adding of products +```shell script +http --json http://[::]:5000/product/add email=your_email_here password=your_password_here specs:=@ipv6-only-vm.json +``` + +#### 2. Listing of products +```shell script +http --json http://[::]:5000/product/list +``` + +#### 3. Ordering products +```shell script +http --json http://[::]:5000/product/order email=your_email_here password=your_password_here product_id=5332cb89453d495381e2b2167f32c842 cpu=1 ram=1gb os-disk-space=10gb os=alpine +``` + +#### 4. Listing users orders + +```shell script +http --json GET http://[::]:5000/order/list email=your_email_here password=your_password_here +``` + + +#### 5. Registering user's payment method (credit card for now using Stripe) + +```shell script +http --json http://[::]:5000/user/register_payment card_number=4111111111111111 cvc=123 expiry_year=2020 expiry_month=8 card_holder_name="The test user" email=your_email_here password=your_password_here +``` \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..cecbc97 --- /dev/null +++ b/config.py @@ -0,0 +1,8 @@ +import configparser +from etcd_wrapper import EtcdWrapper + + +config = configparser.ConfigParser() +config.read('pay.conf') + +etcd_client = EtcdWrapper(host=config['etcd']['host'], port=config['etcd']['port']) diff --git a/etcd_wrapper.py b/etcd_wrapper.py new file mode 100644 index 0000000..73e2c3c --- /dev/null +++ b/etcd_wrapper.py @@ -0,0 +1,75 @@ +import etcd3 +import json + +from functools import wraps + +from uncloud import UncloudException +from uncloud.common import logger + + +class EtcdEntry: + def __init__(self, meta_or_key, value, value_in_json=False): + if hasattr(meta_or_key, 'key'): + # if meta has attr 'key' then get it + self.key = meta_or_key.key.decode('utf-8') + else: + # otherwise meta is the 'key' + self.key = meta_or_key + self.value = value.decode('utf-8') + + if value_in_json: + self.value = json.loads(self.value) + + +def readable_errors(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except etcd3.exceptions.ConnectionFailedError: + raise UncloudException('Cannot connect to etcd: is etcd running as configured in uncloud.conf?') + except etcd3.exceptions.ConnectionTimeoutError as err: + raise etcd3.exceptions.ConnectionTimeoutError('etcd connection timeout.') from err + except Exception: + logger.exception('Some etcd error occured. See syslog for details.') + + return wrapper + + +class EtcdWrapper: + @readable_errors + def __init__(self, *args, **kwargs): + self.client = etcd3.client(*args, **kwargs) + + @readable_errors + def get(self, *args, value_in_json=False, **kwargs): + _value, _key = self.client.get(*args, **kwargs) + if _key is None or _value is None: + return None + return EtcdEntry(_key, _value, value_in_json=value_in_json) + + @readable_errors + def put(self, *args, value_in_json=False, **kwargs): + _key, _value = args + if value_in_json: + _value = json.dumps(_value) + + if not isinstance(_key, str): + _key = _key.decode('utf-8') + + return self.client.put(_key, _value, **kwargs) + + @readable_errors + def get_prefix(self, *args, value_in_json=False, raise_exception=True, **kwargs): + event_iterator = self.client.get_prefix(*args, **kwargs) + for e in event_iterator: + yield EtcdEntry(*e[::-1], value_in_json=value_in_json) + + @readable_errors + def watch_prefix(self, key, raise_exception=True, value_in_json=False): + event_iterator, cancel = self.client.watch_prefix(key) + for e in event_iterator: + if hasattr(e, '_event'): + e = e._event + if e.type == e.PUT: + yield EtcdEntry(e.kv.key, e.kv.value, value_in_json=value_in_json) diff --git a/helper.py b/helper.py new file mode 100644 index 0000000..c2000f5 --- /dev/null +++ b/helper.py @@ -0,0 +1,62 @@ +import config +from stripe_utils import StripeUtils + +etcd_client = config.etcd_client + + +def get_plan_id_from_product(product): + plan_id = 'ucloud-v1-' + plan_id += product['name'].strip().replace(' ', '-') + # plan_id += '-' + product['type'] + return plan_id + + +def get_order_id(): + order_id_kv = etcd_client.get('/v1/last_order_id') + if order_id_kv is not None: + order_id = int(order_id_kv.value) + 1 + else: + order_id = 0 + etcd_client.put('/v1/last_order_id', str(order_id)) + return 'OR-{}'.format(order_id) + + +def get_pricing(price_in_chf_cents, product_type, recurring_period): + if product_type == 'recurring': + return 'CHF {}/{}'.format(price_in_chf_cents/100, recurring_period) + elif product_type == 'one-time': + return 'CHF {} (One time charge)'.format(price_in_chf_cents/100) + + +def get_user_friendly_product(product_dict): + uf_product = { + 'name': product_dict['name'], + 'description': product_dict['description'], + 'product_id': product_dict['usable-id'], + 'pricing': get_pricing( + product_dict['price'], product_dict['type'], product_dict['recurring_period'] + ) + } + if product_dict['type'] == 'recurring': + uf_product['minimum_subscription_period'] = product_dict['minimum_subscription_period'] + return uf_product + + +def get_token(card_number, cvc, exp_month, exp_year): + stripe_utils = StripeUtils() + token_response = stripe_utils.get_token_from_card( + card_number, cvc, exp_month, exp_year + ) + if token_response['response_object']: + return token_response['response_object'].id + else: + return None + + +def resolve_product_usable_id(usable_id, etcd_client): + products = etcd_client.get_prefix('/v1/products/', value_in_json=True) + for p in products: + if p.value['usable-id'] == usable_id: + print(p.value['uuid'], usable_id) + return p.value['uuid'] + return None diff --git a/ldap_manager.py b/ldap_manager.py new file mode 100644 index 0000000..f8cfaa3 --- /dev/null +++ b/ldap_manager.py @@ -0,0 +1,64 @@ +import hashlib +import random +import base64 + +from ldap3 import Server, Connection, ObjectDef, Reader, ALL + + +class LdapManager: + def __init__(self, server, admin_dn, admin_password): + self.server = Server(server, get_info=ALL) + self.conn = Connection(server, admin_dn, admin_password, auto_bind=True) + self.person_obj_def = ObjectDef('inetOrgPerson', self.conn) + + def get(self, query=None, search_base='dc=ungleich,dc=ch'): + kwargs = { + 'connection': self.conn, + 'object_def': self.person_obj_def, + 'base': search_base, + } + if query: + kwargs['query'] = query + r = Reader(**kwargs) + return r.search() + + def is_password_valid(self, email, password, **kwargs): + entries = self.get(query='(mail={})'.format(email), **kwargs) + if entries: + password_in_ldap = entries[0].userPassword.value + return self._check_password(password_in_ldap, password) + return False + + @staticmethod + def _check_password(tagged_digest_salt, password): + digest_salt_b64 = tagged_digest_salt[6:] + digest_salt = base64.decodebytes(digest_salt_b64) + digest = digest_salt[:20] + salt = digest_salt[20:] + + sha = hashlib.sha1(password.encode('utf-8')) + sha.update(salt) + + return digest == sha.digest() + + @staticmethod + def ssha_password(password): + """ + Apply the SSHA password hashing scheme to the given *password*. + *password* must be a :class:`bytes` object, containing the utf-8 + encoded password. + + Return a :class:`bytes` object containing ``ascii``-compatible data + which can be used as LDAP value, e.g. after armoring it once more using + base64 or decoding it to unicode from ``ascii``. + """ + SALT_BYTES = 15 + + sha1 = hashlib.sha1() + salt = random.SystemRandom().getrandbits(SALT_BYTES * 8).to_bytes(SALT_BYTES, 'little') + sha1.update(password) + sha1.update(salt) + + digest = sha1.digest() + passwd = b'{SSHA}' + base64.b64encode(digest + salt) + return passwd diff --git a/products/ipv6-only-django.json b/products/ipv6-only-django.json new file mode 100644 index 0000000..b3d8730 --- /dev/null +++ b/products/ipv6-only-django.json @@ -0,0 +1,27 @@ +{ + "usable-id": "ipv6-only-django-hosting", + "active": true, + "name": "IPv6 Only Django Hosting", + "description": "Host your Django application on our shiny IPv6 Only VM", + "recurring_period": "month", + "features": { + "cpu": { + "unit": {"value": 1, "type":"int"}, + "price_per_unit_per_period": 3, + "one_time_fee": 0, + "constant": false + }, + "ram": { + "unit": {"value": 1, "type":"int"}, + "price_per_unit_per_period": 4, + "one_time_fee": 0, + "constant": false + }, + "os-disk-space": { + "unit": {"value": 10, "type":"int"}, + "one_time_fee": 0, + "price_per_unit_per_period": 3.5, + "constant": false + } + } +} diff --git a/products/ipv6-only-vm.json b/products/ipv6-only-vm.json new file mode 100644 index 0000000..6b21b26 --- /dev/null +++ b/products/ipv6-only-vm.json @@ -0,0 +1,33 @@ +{ + "usable-id": "ipv6-only-vm", + "active": true, + "name": "IPv6 Only VM", + "description": "IPv6 Only VM are accessible to only those having IPv6 for themselves", + "recurring_period": "month", + "features": { + "cpu": { + "unit": {"value": 1, "type":"int"}, + "price_per_unit_per_period": 3, + "one_time_fee": 0, + "constant": false + }, + "ram": { + "unit": {"value": 1, "type":"int"}, + "price_per_unit_per_period": 4, + "one_time_fee": 0, + "constant": false + }, + "os-disk-space": { + "unit": {"value": 10, "type":"int"}, + "one_time_fee": 0, + "price_per_unit_per_period": 4, + "constant": false + }, + "os": { + "unit": {"value": 1, "type":"str"}, + "one_time_fee": 0, + "price_per_unit_per_period": 0, + "constant": false + } + } +} diff --git a/products/ipv6-only-vpn.json b/products/ipv6-only-vpn.json new file mode 100644 index 0000000..43ed7bd --- /dev/null +++ b/products/ipv6-only-vpn.json @@ -0,0 +1,15 @@ +{ + "usable-id": "ipv6-only-vpn", + "active": true, + "name": "IPv6 Only VPN", + "description": "IPv6 VPN enable you to access IPv6 only websites and more", + "recurring_period": "month", + "features": { + "vpn": { + "unit": {"value": 1, "type": "int"}, + "price_per_unit_per_period": 10, + "one_time_fee": 0, + "constant": true + } + } +} diff --git a/products/membership.json b/products/membership.json new file mode 100644 index 0000000..14596fa --- /dev/null +++ b/products/membership.json @@ -0,0 +1,15 @@ +{ + "usable-id": "membership", + "active": true, + "name": "Membership", + "description": "Membership to use uncloud-pay", + "recurring_period": "eternity", + "features": { + "membership": { + "unit": {"value": 1, "type":"int"}, + "price_per_unit_per_period": 0, + "one_time_fee": 5, + "constant": true + } + } +} diff --git a/schemas.py b/schemas.py new file mode 100644 index 0000000..9d0c97f --- /dev/null +++ b/schemas.py @@ -0,0 +1,134 @@ +import logging +import config + +from helper import resolve_product_usable_id + +etcd_client = config.etcd_client + + +class ValidationException(Exception): + """Validation Error""" + + +class Field: + def __init__(self, _name, _type, _value=None, validators=None): + if validators is None: + validators = [] + + assert isinstance(validators, list) + + self.name = _name + self.value = _value + self.type = _type + self.validators = validators + + def is_valid(self): + if not isinstance(self.value, self.type): + try: + self.value = self.type(self.value) + except Exception: + raise ValidationException("Incorrect Type for '{}' field".format(self.name)) + + for validator in self.validators: + validator() + + def __repr__(self): + return self.name + + +class BaseSchema: + def __init__(self): + self.fields = [getattr(self, field) for field in dir(self) if isinstance(getattr(self, field), Field)] + + def validation(self): + # custom validation is optional + return True + + def is_valid(self): + for field in self.fields: + field.is_valid() + + for parent in self.__class__.__bases__: + parent.validation(self) + + self.validation() + + for field in self.fields: + setattr(self, field.name, field.value) + + def return_data(self): + return { + field.name: field.value + for field in self.fields + } + + +def get(dictionary: dict, key: str, return_default=False, default=None): + if dictionary is None: + raise ValidationException('No data provided at all.') + try: + value = dictionary[key] + except KeyError: + if return_default: + return default + raise ValidationException("Missing data for '{}' field.".format(key)) + else: + return value + + +class AddProductSchema(BaseSchema): + def __init__(self, data): + self.email = Field('email', str, get(data, 'email')) + self.password = Field('password', str, get(data, 'password')) + self.specs = Field('specs', dict, get(data, 'specs')) + super().__init__() + + +class UserRegisterPaymentSchema(BaseSchema): + def __init__(self, data): + self.email = Field('email', str, get(data, 'email')) + self.password = Field('password', str, get(data, 'password')) + self.card_number = Field('card_number', str, get(data, 'card_number')) + self.cvc = Field('cvc', str, get(data, 'cvc')) + self.expiry_year = Field('expiry_year', int, get(data, 'expiry_year')) + self.expiry_month = Field('expiry_month', int, get(data, 'expiry_month')) + self.card_holder_name = Field('card_holder_name', str, get(data, 'card_holder_name')) + + super().__init__() + + +class ProductOrderSchema(BaseSchema): + def __init__(self, data): + self.email = Field('email', str, get(data, 'email')) + self.password = Field('password', str, get(data, 'password')) + self.product_id = Field('product_id', str, get(data, 'product_id'), validators=[self.product_id_validation]) + + super().__init__() + + def product_id_validation(self): + product_uuid = resolve_product_usable_id(self.product_id.value, etcd_client) + if product_uuid: + self.product_id.value = product_uuid + else: + raise ValidationException('Invalid Product ID') + + +class OrderListSchema(BaseSchema): + def __init__(self, data): + self.email = Field('email', str, get(data, 'email')) + self.password = Field('password', str, get(data, 'password')) + super().__init__() + +def make_return_message(err, status_code=200): + logging.debug('message: {}'.format(str(err))) + return {'message': str(err)}, status_code + + +def create_schema(specification, data): + fields = {} + for feature_name, feature_detail in specification['features'].items(): + if not feature_detail['constant']: + fields[feature_name] = Field(feature_name, eval(feature_detail['unit']['type']), get(data, feature_name)) + + return type('{}Schema'.format(specification['name']), (BaseSchema,), fields) + diff --git a/stripe_utils.py b/stripe_utils.py new file mode 100644 index 0000000..5ffb443 --- /dev/null +++ b/stripe_utils.py @@ -0,0 +1,490 @@ +import json +import re +import stripe +import stripe.error +import logging + +from config import etcd_client as client, config as config + +stripe.api_key = config['stripe']['private_key'] + + +def handle_stripe_error(f): + def handle_problems(*args, **kwargs): + response = { + 'paid': False, + 'response_object': None, + 'error': None + } + + common_message = "Currently it's not possible to make payments." + try: + response_object = f(*args, **kwargs) + response = { + 'response_object': response_object, + 'error': None + } + return response + except stripe.error.CardError as e: + # Since it's a decline, stripe.error.CardError will be caught + body = e.json_body + err = body['error'] + response.update({'error': err['message']}) + logging.error(str(e)) + return response + except stripe.error.RateLimitError: + response.update( + {'error': "Too many requests made to the API too quickly"}) + return response + except stripe.error.InvalidRequestError as e: + logging.error(str(e)) + response.update({'error': "Invalid parameters"}) + return response + except stripe.error.AuthenticationError as e: + # Authentication with Stripe's API failed + # (maybe you changed API keys recently) + logging.error(str(e)) + response.update({'error': common_message}) + return response + except stripe.error.APIConnectionError as e: + logging.error(str(e)) + response.update({'error': common_message}) + return response + except stripe.error.StripeError as e: + # maybe send email + logging.error(str(e)) + response.update({'error': common_message}) + return response + except Exception as e: + # maybe send email + logging.error(str(e)) + response.update({'error': common_message}) + return response + + return handle_problems + + +class StripeUtils(object): + CURRENCY = 'chf' + INTERVAL = 'month' + SUCCEEDED_STATUS = 'succeeded' + STRIPE_PLAN_ALREADY_EXISTS = 'Plan already exists' + STRIPE_NO_SUCH_PLAN = 'No such plan' + PLAN_EXISTS_ERROR_MSG = 'Plan {} exists already.\nCreating a local StripePlan now.' + PLAN_DOES_NOT_EXIST_ERROR_MSG = 'Plan {} does not exist.' + + def __init__(self): + self.stripe = stripe + + @handle_stripe_error + def card_exists(self, customer, cc_number, exp_month, exp_year, cvc): + token_obj = stripe.Token.create( + card={ + 'number': cc_number, + 'exp_month': exp_month, + 'exp_year': exp_year, + 'cvc': cvc, + }, + ) + cards = stripe.Customer.list_sources( + customer, + limit=20, + object='card' + ) + + for card in cards.data: + if (card.fingerprint == token_obj.card.fingerprint and + int(card.exp_month) == int(exp_month) and int(card.exp_year) == int(exp_year)): + return True + return False + + @staticmethod + def get_stripe_customer_from_email(email): + customer = stripe.Customer.list(limit=1, email=email) + return customer.data[0] if len(customer.data) == 1 else None + + @staticmethod + def update_customer_token(customer, token): + customer.source = token + customer.save() + + @handle_stripe_error + def get_token_from_card(self, cc_number, cvc, expiry_month, expiry_year): + token_obj = stripe.Token.create( + card={ + 'number': cc_number, + 'exp_month': expiry_month, + 'exp_year': expiry_year, + 'cvc': cvc, + }, + ) + return token_obj + + @handle_stripe_error + def associate_customer_card(self, stripe_customer_id, token, + set_as_default=False): + customer = stripe.Customer.retrieve(stripe_customer_id) + card = customer.sources.create(source=token) + if set_as_default: + customer.default_source = card.id + customer.save() + return True + + @handle_stripe_error + def dissociate_customer_card(self, stripe_customer_id, card_id): + customer = stripe.Customer.retrieve(stripe_customer_id) + card = customer.sources.retrieve(card_id) + card.delete() + + @handle_stripe_error + def update_customer_card(self, customer_id, token): + customer = stripe.Customer.retrieve(customer_id) + current_card_token = customer.default_source + customer.sources.retrieve(current_card_token).delete() + customer.source = token + customer.save() + credit_card_raw_data = customer.sources.data.pop() + new_card_data = { + 'last4': credit_card_raw_data.last4, + 'brand': credit_card_raw_data.brand + } + return new_card_data + + @handle_stripe_error + def get_card_details(self, customer_id): + customer = stripe.Customer.retrieve(customer_id) + credit_card_raw_data = customer.sources.data.pop() + card_details = { + 'last4': credit_card_raw_data.last4, + 'brand': credit_card_raw_data.brand, + 'exp_month': credit_card_raw_data.exp_month, + 'exp_year': credit_card_raw_data.exp_year, + 'fingerprint': credit_card_raw_data.fingerprint, + 'card_id': credit_card_raw_data.id + } + return card_details + + @handle_stripe_error + def get_all_invoices(self, customer_id, created_gt): + return_list = [] + has_more_invoices = True + starting_after = False + while has_more_invoices: + if starting_after: + invoices = stripe.Invoice.list( + limit=10, customer=customer_id, created={'gt': created_gt}, + starting_after=starting_after + ) + else: + invoices = stripe.Invoice.list( + limit=10, customer=customer_id, created={'gt': created_gt} + ) + has_more_invoices = invoices.has_more + for invoice in invoices.data: + sub_ids = [] + for line in invoice.lines.data: + if line.type == 'subscription': + sub_ids.append(line.id) + elif line.type == 'invoiceitem': + sub_ids.append(line.subscription) + else: + sub_ids.append('') + invoice_details = { + 'created': invoice.created, + 'receipt_number': invoice.receipt_number, + 'invoice_number': invoice.number, + 'paid_at': invoice.status_transitions.paid_at if invoice.paid else 0, + 'period_start': invoice.period_start, + 'period_end': invoice.period_end, + 'billing_reason': invoice.billing_reason, + 'discount': invoice.discount.coupon.amount_off if invoice.discount else 0, + 'total': invoice.total, + # to see how many line items we have in this invoice and + # then later check if we have more than 1 + 'lines_data_count': len(invoice.lines.data) if invoice.lines.data is not None else 0, + 'invoice_id': invoice.id, + 'lines_meta_data_csv': ','.join( + [line.metadata.VM_ID if hasattr(line.metadata, 'VM_ID') else '' for line in invoice.lines.data] + ), + 'subscription_ids_csv': ','.join(sub_ids), + 'line_items': invoice.lines.data + } + starting_after = invoice.id + return_list.append(invoice_details) + return return_list + + @handle_stripe_error + def get_cards_details_from_token(self, token): + stripe_token = stripe.Token.retrieve(token) + card_details = { + 'last4': stripe_token.card.last4, + 'brand': stripe_token.card.brand, + 'exp_month': stripe_token.card.exp_month, + 'exp_year': stripe_token.card.exp_year, + 'fingerprint': stripe_token.card.fingerprint, + 'card_id': stripe_token.card.id + } + return card_details + + def check_customer(self, stripe_cus_api_id, user, token): + try: + customer = stripe.Customer.retrieve(stripe_cus_api_id) + except stripe.error.InvalidRequestError: + customer = self.create_customer(token, user.email, user.name) + user.stripecustomer.stripe_id = customer.get( + 'response_object').get('id') + user.stripecustomer.save() + if type(customer) is dict: + customer = customer['response_object'] + return customer + + @handle_stripe_error + def get_customer(self, stripe_api_cus_id): + customer = stripe.Customer.retrieve(stripe_api_cus_id) + # data = customer.get('response_object') + return customer + + @handle_stripe_error + def create_customer(self, token, email, name=None): + if name is None or name.strip() == "": + name = email + customer = self.stripe.Customer.create( + source=token, + description=name, + email=email + ) + return customer + + @handle_stripe_error + def make_charge(self, amount=None, customer=None): + _amount = float(amount) + amount = int(_amount * 100) # stripe amount unit, in cents + charge = self.stripe.Charge.create( + amount=amount, # in cents + currency=self.CURRENCY, + customer=customer + ) + return charge + + @staticmethod + def _get_all_stripe_plans(): + all_stripe_plans = client.get("/v1/stripe_plans") + all_stripe_plans_set = set() + if all_stripe_plans: + all_stripe_plans_obj = json.loads(all_stripe_plans.value) + if all_stripe_plans_obj and len(all_stripe_plans_obj['plans']) > 0: + all_stripe_plans_set = set(all_stripe_plans_obj["plans"]) + return all_stripe_plans_set + + @staticmethod + def _save_all_stripe_plans(stripe_plans): + client.put("/v1/stripe_plans", json.dumps({"plans": list(stripe_plans)})) + + @handle_stripe_error + def get_or_create_stripe_plan(self, product_name, amount, stripe_plan_id, + interval=INTERVAL): + """ + This function checks if a StripePlan with the given + stripe_plan_id already exists. If it exists then the function + returns this object otherwise it creates a new StripePlan and + returns the new object. + + :param amount: The amount in CHF cents + :param name: The name of the Stripe plan to be created. + :param stripe_plan_id: The id of the Stripe plan to be + created. Use get_stripe_plan_id_string function to + obtain the name of the plan to be created + :param interval: The interval for subscription {month, year}. Defaults + to month if not provided + :return: The StripePlan object if it exists else creates a + Plan object in Stripe and a local StripePlan and + returns it. Returns None in case of Stripe error + """ + _amount = float(amount) + amount = int(_amount * 100) # stripe amount unit, in cents + all_stripe_plans = self._get_all_stripe_plans() + if stripe_plan_id in all_stripe_plans: + logging.debug("{} plan exists in db.".format(stripe_plan_id)) + else: + logging.debug(("{} plan DOES NOT exist in db. " + "Creating").format(stripe_plan_id)) + try: + plan_obj = self.stripe.Plan.retrieve(id=stripe_plan_id) + logging.debug("{} plan exists in Stripe".format(stripe_plan_id)) + all_stripe_plans.add(stripe_plan_id) + except stripe.error.InvalidRequestError as e: + if "No such plan" in str(e): + logging.debug("Plan {} does not exist in Stripe, Creating") + plan_obj = self.stripe.Plan.create( + amount=amount, + product={'name': product_name}, + interval=interval, + currency=self.CURRENCY, + id=stripe_plan_id) + logging.debug(self.PLAN_EXISTS_ERROR_MSG.format(stripe_plan_id)) + all_stripe_plans.add(stripe_plan_id) + self._save_all_stripe_plans(all_stripe_plans) + return stripe_plan_id + + @handle_stripe_error + def delete_stripe_plan(self, stripe_plan_id): + """ + Deletes the Plan in Stripe and also deletes the local db copy + of the plan if it exists + + :param stripe_plan_id: The stripe plan id that needs to be + deleted + :return: True if the plan was deleted successfully from + Stripe, False otherwise. + """ + return_value = False + try: + plan = self.stripe.Plan.retrieve(stripe_plan_id) + plan.delete() + return_value = True + all_stripe_plans = self._get_all_stripe_plans() + all_stripe_plans.remove(stripe_plan_id) + self._save_all_stripe_plans(all_stripe_plans) + except stripe.error.InvalidRequestError as e: + if self.STRIPE_NO_SUCH_PLAN in str(e): + logging.debug( + self.PLAN_DOES_NOT_EXIST_ERROR_MSG.format(stripe_plan_id)) + return return_value + + @handle_stripe_error + def subscribe_customer_to_plan(self, customer, plans, trial_end=None): + """ + Subscribes the given customer to the list of given plans + + :param customer: The stripe customer identifier + :param plans: A list of stripe plans. + :param trial_end: An integer representing when the Stripe subscription + is supposed to end + Ref: https://stripe.com/docs/api/python#create_subscription-items + e.g. + plans = [ + { + "plan": "dcl-v1-cpu-2-ram-5gb-ssd-10gb", + }, + ] + :return: The subscription StripeObject + """ + + subscription_result = self.stripe.Subscription.create( + customer=customer, items=plans, trial_end=trial_end + ) + return subscription_result + + @handle_stripe_error + def set_subscription_metadata(self, subscription_id, metadata): + subscription = stripe.Subscription.retrieve(subscription_id) + subscription.metadata = metadata + subscription.save() + + @handle_stripe_error + def unsubscribe_customer(self, subscription_id): + """ + Cancels a given subscription + + :param subscription_id: The Stripe subscription id string + :return: + """ + sub = stripe.Subscription.retrieve(subscription_id) + return sub.delete() + + @handle_stripe_error + def make_payment(self, customer, amount, token): + charge = self.stripe.Charge.create( + amount=amount, # in cents + currency=self.CURRENCY, + customer=customer + ) + return charge + + @staticmethod + def get_stripe_plan_id(cpu, ram, ssd, version, app='dcl', hdd=None, + price=None): + """ + Returns the Stripe plan id string of the form + `dcl-v1-cpu-2-ram-5gb-ssd-10gb` based on the input parameters + + :param cpu: The number of cores + :param ram: The size of the RAM in GB + :param ssd: The size of ssd storage in GB + :param hdd: The size of hdd storage in GB + :param version: The version of the Stripe plans + :param app: The application to which the stripe plan belongs + to. By default it is 'dcl' + :param price: The price for this plan + :return: A string of the form `dcl-v1-cpu-2-ram-5gb-ssd-10gb` + """ + dcl_plan_string = 'cpu-{cpu}-ram-{ram}gb-ssd-{ssd}gb'.format(cpu=cpu, + ram=ram, + ssd=ssd) + if hdd is not None: + dcl_plan_string = '{dcl_plan_string}-hdd-{hdd}gb'.format( + dcl_plan_string=dcl_plan_string, hdd=hdd) + stripe_plan_id_string = '{app}-v{version}-{plan}'.format( + app=app, + version=version, + plan=dcl_plan_string + ) + if price is not None: + stripe_plan_id_string_with_price = '{}-{}chf'.format( + stripe_plan_id_string, + round(price, 2) + ) + return stripe_plan_id_string_with_price + else: + return stripe_plan_id_string + + @staticmethod + def get_vm_config_from_stripe_id(stripe_id): + """ + Given a string like "dcl-v1-cpu-2-ram-5gb-ssd-10gb" return different + configuration params as a dict + + :param stripe_id|str + :return: dict + """ + pattern = re.compile(r'^dcl-v(\d+)-cpu-(\d+)-ram-(\d+\.?\d*)gb-ssd-(\d+)gb-?(\d*\.?\d*)(chf)?$') + match_res = pattern.match(stripe_id) + if match_res is not None: + price = None + try: + price = match_res.group(5) + except IndexError: + logging.debug("Did not find price in {}".format(stripe_id)) + return { + 'version': match_res.group(1), + 'cores': match_res.group(2), + 'ram': match_res.group(3), + 'ssd': match_res.group(4), + 'price': price + } + + @staticmethod + def get_stripe_plan_name(cpu, memory, disk_size, price): + """ + Returns the Stripe plan name + :return: + """ + return "{cpu} Cores, {memory} GB RAM, {disk_size} GB SSD, " \ + "{price} CHF".format( + cpu=cpu, + memory=memory, + disk_size=disk_size, + price=round(price, 2) + ) + + @handle_stripe_error + def set_subscription_meta_data(self, subscription_id, meta_data): + """ + Adds VM metadata to a subscription + :param subscription_id: Stripe identifier for the subscription + :param meta_data: A dict of meta data to be added + :return: + """ + subscription = stripe.Subscription.retrieve(subscription_id) + subscription.metadata = meta_data + subscription.save() diff --git a/ucloud_pay.py b/ucloud_pay.py new file mode 100644 index 0000000..edee113 --- /dev/null +++ b/ucloud_pay.py @@ -0,0 +1,345 @@ +import json +import time +import logging + +from datetime import datetime +from uuid import uuid4 + +from flask import Flask, request +from flask_restful import Resource, Api + +from config import etcd_client as client, config as config +from stripe_utils import StripeUtils +from ldap_manager import LdapManager +from schemas import ( + make_return_message, ValidationException, UserRegisterPaymentSchema, + AddProductSchema, ProductOrderSchema, OrderListSchema, create_schema +) +from helper import ( + get_plan_id_from_product, get_user_friendly_product, get_order_id, +) + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) +log_formater = logging.Formatter('[%(filename)s:%(lineno)d] %(message)s') + +stream_logger = logging.StreamHandler() +stream_logger.setLevel(logging.DEBUG) +stream_logger.setFormatter(log_formater) + +logger.addHandler(stream_logger) + +app = Flask(__name__) +api = Api(app) +INIT_ORDER_ID = 0 + +ldap_manager = LdapManager(server=config['ldap']['server'], admin_dn=config['ldap']['admin_dn'], + admin_password=config['ldap']['admin_password']) + + +def calculate_charges(specification, data): + one_time_charge = 0 + recurring_charge = 0 + for feature_name, feature_detail in specification['features'].items(): + if feature_detail['constant']: + data[feature_name] = 1 + + if feature_detail['unit']['type'] != 'str': + one_time_charge += feature_detail['one_time_fee'] + recurring_charge += ( + feature_detail['price_per_unit_per_period'] * data[feature_name] / + feature_detail['unit']['value'] + ) + return one_time_charge, recurring_charge + + +class ListProducts(Resource): + @staticmethod + def get(): + products = client.get_prefix('/v1/products/', value_in_json=False) + prod_dict = {} + for p in products: + p = json.loads(p.value) + prod_dict[p['usable-id']] = { + 'name': p['name'], + 'description': p['description'], + 'active': p['active'] + } + logger.debug('Products = {}'.format(prod_dict)) + return prod_dict, 200 + + +class AddProduct(Resource): + @staticmethod + def post(): + data = request.json + logger.debug('Got data: {}'.format(str(data))) + + try: + validator = AddProductSchema(data) + validator.is_valid() + except ValidationException as err: + return make_return_message(err, 400) + else: + if ldap_manager.is_password_valid(data['email'], data['password']): + try: + user = ldap_manager.get('(mail={})'.format(data['email']))[0] + user = json.loads(user.entry_to_json()) + uid, ou, *dc = user['dn'].replace('ou=', '').replace('dc=', '').replace('uid=', '').split(',') + except Exception as err: + logger.error(str(err)) + return {'message': 'No such user exists'} + else: + if ou != config['ldap']['internal_user_ou']: + logger.error('User (email=%s) does not have access to create product', validator.email) + return {'message': 'Forbidden'}, 403 + else: + product_uuid = uuid4().hex + product_key = '/v1/products/{}'.format(product_uuid) + product_value = validator.specs + product_value['uuid'] = product_uuid + + logger.debug('Adding product data: {}'.format(str(product_value))) + client.put(product_key, product_value, value_in_json=True) + return {'message': 'Product created'}, 200 + + else: + return {'message': 'Wrong Credentials'}, 403 + + +class UserRegisterPayment(Resource): + @staticmethod + def post(): + data = request.json + logger.debug('Got data: {}'.format(str(data))) + try: + validator = UserRegisterPaymentSchema(data) + validator.is_valid() + except ValidationException as err: + return make_return_message(err, 400) + else: + last4 = data['card_number'].strip()[-4:] + + if ldap_manager.is_password_valid(validator.email, validator.password): + stripe_utils = StripeUtils() + + # Does customer already exist ? + stripe_customer = stripe_utils.get_stripe_customer_from_email(validator.email) + + # Does customer already exist ? + if stripe_customer is not None: + logger.debug('Customer {} exists already'.format(validator.email)) + + # Check if the card already exists + ce_response = stripe_utils.card_exists( + stripe_customer.id, cc_number=data['card_number'], + exp_month=int(data['expiry_month']), + exp_year=int(data['expiry_year']), + cvc=data['cvc']) + + if ce_response['response_object']: + message = 'The given card ending in {} exists already.'.format(last4) + return make_return_message(message, 400) + + elif ce_response['response_object'] is False: + # Associate card with user + logger.debug('Adding card ending in {}'.format(last4)) + token_response = stripe_utils.get_token_from_card( + data['card_number'], data['cvc'], data['expiry_month'], + data['expiry_year'] + ) + if token_response['response_object']: + logger.debug('Token {}'.format(token_response['response_object'].id)) + resp = stripe_utils.associate_customer_card( + stripe_customer.id, token_response['response_object'].id + ) + if resp['response_object']: + return make_return_message( + 'Card ending in {} registered as your payment source'.format(last4) + ) + else: + return make_return_message('Error with payment gateway. Contact support', 400) + else: + return make_return_message('Error: {}'.format(ce_response['error']), 400) + else: + # Stripe customer does not exist, create a new one + logger.debug('Customer {} does not exist, creating new'.format(validator.email)) + token_response = stripe_utils.get_token_from_card( + validator.card_number, validator.cvc, validator.expiry_month, + validator.expiry_year + ) + if token_response['response_object']: + logger.debug('Token {}'.format(token_response['response_object'].id)) + + # Create stripe customer + stripe_customer_resp = stripe_utils.create_customer( + name=validator.card_holder_name, + token=token_response['response_object'].id, + email=validator.email + ) + stripe_customer = stripe_customer_resp['response_object'] + + if stripe_customer: + logger.debug('Created stripe customer {}'.format(stripe_customer.id)) + return make_return_message( + 'Card ending in {} registered as your payment source'.format(last4) + ) + else: + return make_return_message('Error with card. Contact support', 400) + else: + return make_return_message('Error with payment gateway. Contact support', 400) + else: + return make_return_message('Wrong Credentials', 403) + + +class ProductOrder(Resource): + @staticmethod + def post(): + data = request.json + try: + validator = ProductOrderSchema(data) + validator.is_valid() + except ValidationException as err: + return make_return_message(err, 400) + else: + if ldap_manager.is_password_valid(validator.email, validator.password): + stripe_utils = StripeUtils() + logger.debug('Product ID = {}'.format(validator.product_id)) + + # Validate the given product is ok + product = client.get('/v1/products/{}'.format(validator.product_id), value_in_json=True) + if not product: + return make_return_message('Invalid Product', 400) + + product = product.value + + customer_previous_orders = client.get_prefix( + '/v1/user/{}'.format(validator.email), value_in_json=True + ) + membership = next(filter(lambda o: o.value['product'] == 'membership', customer_previous_orders), None) + if membership is None and data['product_id'] != 'membership': + return make_return_message('Please buy membership first to use this facility') + + logger.debug('Got product {}'.format(product)) + + # Check the user has a payment source added + stripe_customer = stripe_utils.get_stripe_customer_from_email(validator.email) + + if not stripe_customer or len(stripe_customer.sources) == 0: + return make_return_message('Please register first.', 400) + + try: + product_schema = create_schema(product, data) + product_schema = product_schema() + product_schema.is_valid() + except ValidationException as err: + return make_return_message(err, 400) + else: + transformed_data = product_schema.return_data() + logger.debug('Tranformed data: {}'.format(transformed_data)) + one_time_charge, recurring_charge = calculate_charges(product, transformed_data) + recurring_charge = int(recurring_charge) + + # Initiate a one-time/subscription based on product type + if recurring_charge > 0: + logger.debug('Product {} is recurring payment'.format(product['name'])) + plan_id = get_plan_id_from_product(product) + res = stripe_utils.get_or_create_stripe_plan( + product_name=product['name'], + stripe_plan_id=plan_id, amount=recurring_charge, + interval=product['recurring_period'], + ) + if res['response_object']: + logger.debug('Obtained plan {}'.format(plan_id)) + subscription_res = stripe_utils.subscribe_customer_to_plan( + stripe_customer.id, + [{'plan': plan_id}] + ) + subscription_obj = subscription_res['response_object'] + if subscription_obj is None or subscription_obj.status != 'active': + return make_return_message( + 'Error subscribing to plan. Detail: {}'.format(subscription_res['error']), 400 + ) + else: + order_obj = { + 'order_id': get_order_id(), + 'ordered_at': int(time.time()), + 'product': product['usable-id'], + } + client.put('/v1/user/{}/orders'.format(validator.email), order_obj, value_in_json=True) + order_obj['ordered_at'] = datetime.fromtimestamp(order_obj['ordered_at']).strftime('%c') + return make_return_message('Order Successful. Order Details: {}'.format(order_obj)) + else: + logger.error('Could not create plan {}'.format(plan_id)) + + elif recurring_charge == 0 and one_time_charge > 0: + logger.debug('Product {} is one-time payment'.format(product['name'])) + charge_response = stripe_utils.make_charge( + amount=one_time_charge, + customer=stripe_customer.id + ) + stripe_onetime_charge = charge_response.get('response_object') + + # Check if the payment was approved + if not stripe_onetime_charge: + msg = charge_response.get('error') + return make_return_message( + 'Error subscribing to plan. Details: {}'.format(msg), 400 + ) + + order_obj = { + 'order_id': get_order_id(), + 'ordered_at': int(time.time()), + 'product': product['usable-id'], + } + client.put( + '/v1/user/{}/orders'.format(validator.email),order_obj, + value_in_json=True + ) + order_obj['ordered_at'] = datetime.fromtimestamp(order_obj['ordered_at']).strftime('%c') + return {'message': 'Order successful', 'order_details': order_obj}, 200 + else: + return make_return_message('Wrong Credentials', 400) + + +class OrderList(Resource): + @staticmethod + def get(): + data = request.json + try: + validator = OrderListSchema(data) + validator.is_valid() + except ValidationException as err: + return make_return_message(err, 400) + else: + print(validator.email, validator.password) + if not ldap_manager.is_password_valid(validator.email, validator.password): + return {'message': 'Wrong Credentials'}, 403 + + orders = client.get_prefix('/v1/user/{}/orders'.format(validator.email), value_in_json=True) + orders_dict = { + order.value['order_id']: { + 'ordered-at': datetime.fromtimestamp(order.value['ordered_at']).strftime('%c'), + 'product': order.value['product'] + } + for order in orders + } + # for p in orders: + # order_dict = p.value + # order_dict['ordered_at'] = datetime.fromtimestamp( + # order_dict['ordered_at']).strftime('%c') + # order_dict['product'] = order_dict['product']['name'] + # orders_dict[order_dict['order_id']] = order_dict + logger.debug('Orders = {}'.format(orders_dict)) + return orders_dict, 200 + + +api.add_resource(ListProducts, '/product/list') +api.add_resource(AddProduct, '/product/add') +api.add_resource(ProductOrder, '/product/order') +api.add_resource(UserRegisterPayment, '/user/register_payment') +api.add_resource(OrderList, '/order/list') + + +if __name__ == '__main__': + app.run(host='::', port=config['app']['port'], debug=True) \ No newline at end of file