diff --git a/datacenterlight/tests.py b/datacenterlight/tests.py index 7c2f7353..c34c56ba 100644 --- a/datacenterlight/tests.py +++ b/datacenterlight/tests.py @@ -115,8 +115,8 @@ class CeleryTaskTestCase(TestCase): 'response_object').stripe_plan_id}]) stripe_subscription_obj = subscription_result.get('response_object') # Check if the subscription was approved and is active - if stripe_subscription_obj is None or \ - stripe_subscription_obj.status != 'active': + if stripe_subscription_obj is None \ + or stripe_subscription_obj.status != 'active': msg = subscription_result.get('error') raise Exception("Creating subscription failed: {}".format(msg)) diff --git a/dynamicweb/settings/base.py b/dynamicweb/settings/base.py index 08ce457d..29533211 100644 --- a/dynamicweb/settings/base.py +++ b/dynamicweb/settings/base.py @@ -559,7 +559,7 @@ CELERY_RESULT_BACKEND = env('CELERY_RESULT_BACKEND') CELERY_ACCEPT_CONTENT = ['application/json'] CELERY_TASK_SERIALIZER = 'json' CELERY_RESULT_SERIALIZER = 'json' -CELERY_TIMEZONE = 'Europe/Zurich' +#CELERY_TIMEZONE = 'Europe/Zurich' CELERY_MAX_RETRIES = int_env('CELERY_MAX_RETRIES', 5) ENABLE_DEBUG_LOGGING = bool_env('ENABLE_DEBUG_LOGGING') @@ -585,6 +585,9 @@ if ENABLE_DEBUG_LOGGING: }, } +TEST_MANAGE_SSH_KEY_PUBKEY = env('TEST_MANAGE_SSH_KEY_PUBKEY') +TEST_MANAGE_SSH_KEY_HOST = env('TEST_MANAGE_SSH_KEY_HOST') + DEBUG = bool_env('DEBUG') if DEBUG: diff --git a/hosting/forms.py b/hosting/forms.py index 288a8caf..056d0004 100644 --- a/hosting/forms.py +++ b/hosting/forms.py @@ -1,16 +1,22 @@ import datetime +import logging +import subprocess +import tempfile from django import forms -from membership.models import CustomUser from django.contrib.auth import authenticate - from django.utils.translation import ugettext_lazy as _ +from membership.models import CustomUser +from utils.hosting_utils import get_all_public_keys from .models import UserHostingKey +logger = logging.getLogger(__name__) + def generate_ssh_key_name(): - return 'dcl-generated-key-' + datetime.datetime.now().strftime('%m%d%y%H%M') + return 'dcl-generated-key-' + datetime.datetime.now().strftime( + '%m%d%y%H%M') class HostingUserLoginForm(forms.Form): @@ -38,9 +44,7 @@ class HostingUserLoginForm(forms.Form): CustomUser.objects.get(email=email) return email except CustomUser.DoesNotExist: - raise forms.ValidationError("User does not exist") - else: - return email + raise forms.ValidationError(_("User does not exist")) class HostingUserSignupForm(forms.ModelForm): @@ -51,7 +55,8 @@ class HostingUserSignupForm(forms.ModelForm): model = CustomUser fields = ['name', 'email', 'password'] widgets = { - 'name': forms.TextInput(attrs={'placeholder': 'Enter your name or company name'}), + 'name': forms.TextInput( + attrs={'placeholder': 'Enter your name or company name'}), } def clean_confirm_password(self): @@ -65,19 +70,55 @@ class HostingUserSignupForm(forms.ModelForm): class UserHostingKeyForm(forms.ModelForm): private_key = forms.CharField(widget=forms.HiddenInput(), required=False) public_key = forms.CharField(widget=forms.Textarea( - attrs={'class': 'form_public_key', 'placeholder': _('Paste here your public key')}), + attrs={'class': 'form_public_key', + 'placeholder': _('Paste here your public key')}), required=False, ) user = forms.models.ModelChoiceField(queryset=CustomUser.objects.all(), - required=False, widget=forms.HiddenInput()) + required=False, + widget=forms.HiddenInput()) name = forms.CharField(required=False, widget=forms.TextInput( - attrs={'class': 'form_key_name', 'placeholder': _('Give a name to your key')})) + attrs={'class': 'form_key_name', + 'placeholder': _('Give a name to your key')})) def __init__(self, *args, **kwargs): self.request = kwargs.pop("request") super(UserHostingKeyForm, self).__init__(*args, **kwargs) self.fields['name'].label = _('Key name') + def clean_public_key(self): + """ + Validates a public ssh key using `ssh-keygen -lf key.pub` + Also checks if a given key already exists in the database and + alerts the user of it. + :return: + """ + if 'generate' in self.request.POST: + return self.data.get('public_key') + KEY_ERROR_MESSAGE = _("Please input a proper SSH key") + openssh_pubkey_str = self.data.get('public_key').strip() + + if openssh_pubkey_str in get_all_public_keys(self.request.user): + key_name = UserHostingKey.objects.filter( + user_id=self.request.user.id, + public_key=openssh_pubkey_str).first().name + KEY_EXISTS_MESSAGE = _( + "This key exists already with the name \"%(name)s\"") % { + 'name': key_name} + raise forms.ValidationError(KEY_EXISTS_MESSAGE) + + with tempfile.NamedTemporaryFile(delete=True) as tmp_public_key_file: + tmp_public_key_file.write(openssh_pubkey_str.encode('utf-8')) + tmp_public_key_file.flush() + try: + subprocess.check_output( + ['ssh-keygen', '-lf', tmp_public_key_file.name]) + except subprocess.CalledProcessError as cpe: + logger.debug( + "Not a correct ssh format {error}".format(error=str(cpe))) + raise forms.ValidationError(KEY_ERROR_MESSAGE) + return openssh_pubkey_str + def clean_name(self): return self.data.get('name') diff --git a/hosting/locale/de/LC_MESSAGES/django.po b/hosting/locale/de/LC_MESSAGES/django.po index 2f6cee4e..1c0f3faa 100644 --- a/hosting/locale/de/LC_MESSAGES/django.po +++ b/hosting/locale/de/LC_MESSAGES/django.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2017-09-14 02:55+0530\n" +"POT-Creation-Date: 2017-09-14 12:27+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -24,6 +24,9 @@ msgstr "Dein Benutzername und/oder Dein Passwort ist falsch." msgid "Your account is not activated yet." msgstr "Dein Account wurde noch nicht aktiviert." +msgid "User does not exist" +msgstr "Der Benutzer existiert nicht" + msgid "Paste here your public key" msgstr "Füge deinen Public Key ein" @@ -33,6 +36,13 @@ msgstr "Gebe deinem SSH-Key einen Name" msgid "Key name" msgstr "Key-Name" +msgid "Please input a proper SSH key" +msgstr "Bitte verwende einen gültigen SSH-Key" + +#, python-format +msgid "This key exists already with the name \"%(name)s\"" +msgstr "Der SSH-Key mit dem Name \"%(name)s\" existiert bereits" + msgid "All Rights Reserved" msgstr "Alle Rechte vorbehalten" diff --git a/hosting/views.py b/hosting/views.py index 290d5f0f..ed85dbdb 100644 --- a/hosting/views.py +++ b/hosting/views.py @@ -1,3 +1,4 @@ +import logging import uuid import json from time import sleep @@ -18,7 +19,7 @@ from django.views.generic import View, CreateView, FormView, ListView, \ DetailView, \ DeleteView, TemplateView, UpdateView from guardian.mixins import PermissionRequiredMixin -from oca.pool import WrongNameError, WrongIdError +from oca.pool import WrongIdError from stored_messages.api import mark_read from stored_messages.models import Message from stored_messages.settings import stored_messages_settings @@ -29,6 +30,7 @@ from opennebula_api.serializers import VirtualMachineSerializer, \ VirtualMachineTemplateSerializer from utils.forms import BillingAddressForm, PasswordResetRequestForm, \ UserBillingAddressForm +from utils.hosting_utils import get_all_public_keys from utils.mailer import BaseEmail from utils.stripe_utils import StripeUtils from utils.views import PasswordResetViewMixin, PasswordResetConfirmViewMixin, \ @@ -38,8 +40,11 @@ from .forms import HostingUserSignupForm, HostingUserLoginForm, \ from .mixins import ProcessVMSelectionMixin from .models import HostingOrder, HostingBill, HostingPlan, UserHostingKey -CONNECTION_ERROR = "Your VMs cannot be displayed at the moment due to a backend \ - connection error. please try again in a few minutes." +logger = logging.getLogger(__name__) + +CONNECTION_ERROR = "Your VMs cannot be displayed at the moment due to a \ + backend connection error. please try again in a few \ + minutes." class DashboardView(View): @@ -370,17 +375,14 @@ class SSHKeyDeleteView(LoginRequiredMixin, DeleteView): def delete(self, request, *args, **kwargs): owner = self.request.user - manager = OpenNebulaManager() + manager = OpenNebulaManager( + email=owner.email, + password=owner.password + ) pk = self.kwargs.get('pk') # Get user ssh key public_key = UserHostingKey.objects.get(pk=pk).public_key - # Add ssh key to user - try: - manager.remove_public_key(user=owner, public_key=public_key) - except ConnectionError: - pass - except WrongNameError: - pass + manager.manage_public_key([{'value': public_key, 'state': False}]) return super(SSHKeyDeleteView, self).delete(request, *args, **kwargs) @@ -421,6 +423,13 @@ class SSHKeyChoiceView(LoginRequiredMixin, View): user=request.user, public_key=public_key, name=name) filename = name + '_' + str(uuid.uuid4())[:8] + '_private.pem' ssh_key.private_key.save(filename, content) + owner = self.request.user + manager = OpenNebulaManager( + email=owner.email, + password=owner.password + ) + public_key_str = public_key.decode() + manager.manage_public_key([{'value': public_key_str, 'state': True}]) return redirect(reverse_lazy('hosting:ssh_keys'), foo='bar') @@ -465,23 +474,17 @@ class SSHKeyCreateView(LoginRequiredMixin, FormView): }) owner = self.request.user - manager = OpenNebulaManager() - - # Get user ssh key - public_key = str(form.cleaned_data.get('public_key', '')) - # Add ssh key to user - try: - manager.add_public_key( - user=owner, public_key=public_key, merge=True) - except ConnectionError: - pass - except WrongNameError: - pass - + manager = OpenNebulaManager( + email=owner.email, + password=owner.password + ) + public_key = form.cleaned_data['public_key'] + if type(public_key) is bytes: + public_key = public_key.decode() + manager.manage_public_key([{'value': public_key, 'state': True}]) return HttpResponseRedirect(self.success_url) def post(self, request, *args, **kwargs): - print(self.request.POST.dict()) form = self.get_form() required = 'add_ssh' in self.request.POST form.fields['name'].required = required @@ -662,16 +665,12 @@ class PaymentVMView(LoginRequiredMixin, FormView): 'form': form }) return render(request, self.template_name, context) - # For now just get first one - user_key = UserHostingKey.objects.filter( - user=self.request.user).first() # Create a vm using logged user vm_id = manager.create_vm( template_id=vm_template_id, - # XXX: Confi specs=specs, - ssh_key=user_key.public_key, + ssh_key=settings.ONEADMIN_USER_SSH_PUBLIC_KEY, ) # Create a Hosting Order @@ -725,6 +724,19 @@ class PaymentVMView(LoginRequiredMixin, FormView): email = BaseEmail(**email_data) email.send() + # try to see if we have the IP and that if the ssh keys can + # be configured + new_host = manager.get_primary_ipv4(vm_id) + if new_host is not None: + public_keys = get_all_public_keys(owner) + keys = [{'value': key, 'state': True} for key in public_keys] + logger.debug( + "Calling configure on {host} for {num_keys} keys".format( + host=new_host, num_keys=len(keys))) + # Let's delay the task by 75 seconds to be sure that we run + # the cdist configure after the host is up + manager.manage_public_key(keys, hosts=[new_host], countdown=75) + return HttpResponseRedirect( "{url}?{query_params}".format( url=reverse('hosting:orders', kwargs={'pk': order.id}), @@ -919,7 +931,8 @@ class VirtualMachineView(LoginRequiredMixin, View): 'order': HostingOrder.objects.get( vm_id=serializer.data['vm_id']) } - except: + except Exception as ex: + logger.debug("Exception generated {}".format(str(ex))) pass return render(request, self.template_name, context) diff --git a/opennebula_api/models.py b/opennebula_api/models.py index 60f3159c..d584bf26 100644 --- a/opennebula_api/models.py +++ b/opennebula_api/models.py @@ -1,13 +1,14 @@ -import oca -import socket import logging +import socket -from oca.pool import WrongNameError, WrongIdError -from oca.exceptions import OpenNebulaException - +import oca from django.conf import settings +from oca.exceptions import OpenNebulaException +from oca.pool import WrongNameError, WrongIdError +from hosting.models import HostingOrder from utils.models import CustomUser +from utils.tasks import save_ssh_key, save_ssh_key_error_handler from .exceptions import KeyExistsError, UserExistsError, UserCredentialError logger = logging.getLogger(__name__) @@ -17,7 +18,8 @@ class OpenNebulaManager(): """This class represents an opennebula manager.""" def __init__(self, email=None, password=None): - + self.email = email + self.password = password # Get oneadmin client self.oneadmin_client = self._get_opennebula_client( settings.OPENNEBULA_USERNAME, @@ -122,16 +124,19 @@ class OpenNebulaManager(): except WrongNameError: user_id = self.oneadmin_client.call(oca.User.METHODS['allocate'], - user.email, user.password, 'core') - logger.debug('Created a user for CustomObject: {user} with user id = {u_id}', - user=user, - u_id=user_id - ) + user.email, user.password, + 'core') + logger.debug( + 'Created a user for CustomObject: {user} with user id = {u_id}', + user=user, + u_id=user_id + ) return user_id except ConnectionRefusedError: - logger.error('Could not connect to host: {host} via protocol {protocol}'.format( - host=settings.OPENNEBULA_DOMAIN, - protocol=settings.OPENNEBULA_PROTOCOL) + logger.error( + 'Could not connect to host: {host} via protocol {protocol}'.format( + host=settings.OPENNEBULA_DOMAIN, + protocol=settings.OPENNEBULA_PROTOCOL) ) raise ConnectionRefusedError @@ -141,8 +146,9 @@ class OpenNebulaManager(): opennebula_user = user_pool.get_by_name(email) return opennebula_user except WrongNameError as wrong_name_err: - opennebula_user = self.oneadmin_client.call(oca.User.METHODS['allocate'], email, - password, 'core') + opennebula_user = self.oneadmin_client.call( + oca.User.METHODS['allocate'], email, + password, 'core') logger.debug( "User {0} does not exist. Created the user. User id = {1}", email, @@ -150,9 +156,10 @@ class OpenNebulaManager(): ) return opennebula_user except ConnectionRefusedError: - logger.info('Could not connect to host: {host} via protocol {protocol}'.format( - host=settings.OPENNEBULA_DOMAIN, - protocol=settings.OPENNEBULA_PROTOCOL) + logger.info( + 'Could not connect to host: {host} via protocol {protocol}'.format( + host=settings.OPENNEBULA_DOMAIN, + protocol=settings.OPENNEBULA_PROTOCOL) ) raise ConnectionRefusedError @@ -161,9 +168,10 @@ class OpenNebulaManager(): user_pool = oca.UserPool(self.oneadmin_client) user_pool.info() except ConnectionRefusedError: - logger.info('Could not connect to host: {host} via protocol {protocol}'.format( - host=settings.OPENNEBULA_DOMAIN, - protocol=settings.OPENNEBULA_PROTOCOL) + logger.info( + 'Could not connect to host: {host} via protocol {protocol}'.format( + host=settings.OPENNEBULA_DOMAIN, + protocol=settings.OPENNEBULA_PROTOCOL) ) raise return user_pool @@ -183,9 +191,10 @@ class OpenNebulaManager(): raise ConnectionRefusedError except ConnectionRefusedError: - logger.info('Could not connect to host: {host} via protocol {protocol}'.format( - host=settings.OPENNEBULA_DOMAIN, - protocol=settings.OPENNEBULA_PROTOCOL) + logger.info( + 'Could not connect to host: {host} via protocol {protocol}'.format( + host=settings.OPENNEBULA_DOMAIN, + protocol=settings.OPENNEBULA_PROTOCOL) ) raise ConnectionRefusedError # For now we'll just handle all other errors as connection errors @@ -208,6 +217,33 @@ class OpenNebulaManager(): except: raise ConnectionRefusedError + def get_primary_ipv4(self, vm_id): + """ + Returns the primary IPv4 of the given vm. + To be changed later. + + :return: An IP address string, if it exists else returns None + """ + all_ipv4s = self.get_vm_ipv4_addresses(vm_id) + if len(all_ipv4s) > 0: + return all_ipv4s[0] + else: + return None + + def get_vm_ipv4_addresses(self, vm_id): + """ + Returns a list of IPv4 addresses of the given vm + + :param vm_id: The ID of the vm + :return: + """ + ipv4s = [] + vm = self.get_vm(vm_id) + for nic in vm.template.nics: + if hasattr(nic, 'ip'): + ipv4s.append(nic.ip) + return ipv4s + def create_vm(self, template_id, specs, ssh_key=None, vm_name=None): template = self.get_template(template_id) @@ -258,7 +294,8 @@ class OpenNebulaManager(): vm_specs += "" if ssh_key: - vm_specs += "{ssh}".format(ssh=ssh_key) + vm_specs += "{ssh}".format( + ssh=ssh_key) vm_specs += """YES @@ -312,9 +349,11 @@ class OpenNebulaManager(): template_pool.info() return template_pool except ConnectionRefusedError: - logger.info('Could not connect to host: {host} via protocol {protocol}'.format( - host=settings.OPENNEBULA_DOMAIN, - protocol=settings.OPENNEBULA_PROTOCOL) + logger.info( + """Could not connect to host: {host} via protocol + {protocol}""".format( + host=settings.OPENNEBULA_DOMAIN, + protocol=settings.OPENNEBULA_PROTOCOL) ) raise ConnectionRefusedError except: @@ -347,7 +386,8 @@ class OpenNebulaManager(): except: raise ConnectionRefusedError - def create_template(self, name, cores, memory, disk_size, core_price, memory_price, + def create_template(self, name, cores, memory, disk_size, core_price, + memory_price, disk_size_price, ssh=''): """Create and add a new template to opennebula. :param name: A string representation describing the template. @@ -490,3 +530,57 @@ class OpenNebulaManager(): except ConnectionError: raise + + def manage_public_key(self, keys, hosts=None, countdown=0): + """ + A function that manages the supplied keys in the + authorized_keys file of the given list of hosts. If hosts + parameter is not supplied, all hosts of this customer + will be configured with the supplied keys + + :param keys: A list of ssh keys that are to be added/removed + A key should be a dict of the form + { + 'value': 'sha-.....', # public key as string + 'state': True # whether key is to be added or + } # removed + :param hosts: A list of hosts IP addresses + :param countdown: Parameter to be passed to celery apply_async + Allows to delay a task by `countdown` number of seconds + :return: + """ + if hosts is None: + hosts = self.get_all_hosts() + + if len(hosts) > 0 and len(keys) > 0: + save_ssh_key.apply_async((hosts, keys), countdown=countdown, + link_error=save_ssh_key_error_handler.s()) + else: + logger.debug( + "Keys and/or hosts are empty, so not managing any keys") + + def get_all_hosts(self): + """ + A utility function to obtain all hosts of this owner + :return: A list of hosts IP addresses, empty if none exist + """ + owner = CustomUser.objects.filter( + email=self.email).first() + all_orders = HostingOrder.objects.filter(customer__user=owner) + hosts = [] + if len(all_orders) > 0: + logger.debug("The user {} has 1 or more VMs. We need to configure " + "the ssh keys.".format(self.email)) + for order in all_orders: + try: + vm = self.get_vm(order.vm_id) + for nic in vm.template.nics: + if hasattr(nic, 'ip'): + hosts.append(nic.ip) + except WrongIdError: + logger.debug( + "VM with ID {} does not exist".format(order.vm_id)) + else: + logger.debug("The user {} has no VMs. We don't need to configure " + "the ssh keys.".format(self.email)) + return hosts diff --git a/requirements.txt b/requirements.txt index 8d9c68c5..6446a5c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -95,4 +95,6 @@ pycodestyle==2.3.1 pyflakes==1.5.0 billiard==3.5.0.3 amqp==2.2.1 -vine==1.1.4 \ No newline at end of file +vine==1.1.4 +#git+https://github.com/ungleich/cdist.git#egg=cdist +file:///home/app/cdist#egg=cdist diff --git a/utils/hosting_utils.py b/utils/hosting_utils.py new file mode 100644 index 00000000..7c1a83ad --- /dev/null +++ b/utils/hosting_utils.py @@ -0,0 +1,11 @@ +from hosting.models import UserHostingKey + + +def get_all_public_keys(customer): + """ + Returns all the public keys of the user + :param customer: The customer whose public keys are needed + :return: A list of public keys + """ + return UserHostingKey.objects.filter(user_id=customer.id).values_list( + "public_key", flat=True) diff --git a/utils/tasks.py b/utils/tasks.py index 5334b507..d66c37ee 100644 --- a/utils/tasks.py +++ b/utils/tasks.py @@ -1,8 +1,14 @@ +import tempfile + +import cdist +from cdist.integration import configure_hosts_simple +from celery.result import AsyncResult from celery.utils.log import get_task_logger from django.conf import settings -from dynamicweb.celery import app from django.core.mail import EmailMessage +from dynamicweb.celery import app + logger = get_task_logger(__name__) @@ -18,3 +24,72 @@ def send_plain_email_task(self, email_data): """ email = EmailMessage(**email_data) email.send() + + +@app.task(bind=True, max_retries=settings.CELERY_MAX_RETRIES) +def save_ssh_key(self, hosts, keys): + """ + Saves ssh key into the VMs of a user using cdist + + :param hosts: A list of hosts to be configured + :param keys: A list of keys to be added. A key should be dict of the + form { + 'value': 'sha-.....', # public key as string + 'state': True # whether key is to be added or + } # removed + """ + logger.debug("""Running save_ssh_key task for + Hosts: {hosts_str} + Keys: {keys_str}""".format(hosts_str=", ".join(hosts), + keys_str=", ".join([ + "{value}->{state}".format( + value=key.get('value'), + state=str( + key.get('state'))) + for key in keys])) + ) + return_value = True + with tempfile.NamedTemporaryFile(delete=True) as tmp_manifest: + # Generate manifest to be used for configuring the hosts + lines_list = [ + ' --key "{key}" --state {state} \\\n'.format( + key=key['value'], + state='present' if key['state'] else 'absent' + ).encode('utf-8') + for key in keys] + lines_list.insert(0, b'__ssh_authorized_keys root \\\n') + tmp_manifest.writelines(lines_list) + tmp_manifest.flush() + try: + configure_hosts_simple(hosts, + tmp_manifest.name, + verbose=cdist.argparse.VERBOSE_TRACE) + except Exception as cdist_exception: + logger.error(cdist_exception) + return_value = False + email_data = { + 'subject': "celery save_ssh_key error - task id {0}".format( + self.request.id.__str__()), + 'from_email': settings.DCL_SUPPORT_FROM_ADDRESS, + 'to': ['info@ungleich.ch'], + 'body': "Task Id: {0}\nResult: {1}\nTraceback: {2}".format( + self.request.id.__str__(), False, str(cdist_exception)), + } + send_plain_email_task(email_data) + return return_value + + +@app.task +def save_ssh_key_error_handler(uuid): + result = AsyncResult(uuid) + exc = result.get(propagate=False) + logger.error('Task {0} raised exception: {1!r}\n{2!r}'.format( + uuid, exc, result.traceback)) + email_data = { + 'subject': "[celery error] Save SSH key error {0}".format(uuid), + 'from_email': settings.DCL_SUPPORT_FROM_ADDRESS, + 'to': ['info@ungleich.ch'], + 'body': "Task Id: {0}\nResult: {1}\nTraceback: {2}".format( + uuid, exc, result.traceback), + } + send_plain_email_task(email_data) diff --git a/utils/tests.py b/utils/tests.py index c4608e73..d5c2d726 100644 --- a/utils/tests.py +++ b/utils/tests.py @@ -1,16 +1,20 @@ import uuid +from time import sleep from unittest.mock import patch import stripe +from celery.result import AsyncResult +from django.conf import settings from django.http.request import HttpRequest from django.test import Client -from django.test import TestCase +from django.test import TestCase, override_settings +from unittest import skipIf from model_mommy import mommy from datacenterlight.models import StripePlan from membership.models import StripeCustomer from utils.stripe_utils import StripeUtils -from django.conf import settings +from .tasks import save_ssh_key class BaseTestCase(TestCase): @@ -235,3 +239,57 @@ class StripePlanTestCase(TestStripeCustomerDescription): 'response_object').stripe_plan_id}]) self.assertIsNone(result.get('response_object'), None) self.assertIsNotNone(result.get('error')) + + +class SaveSSHKeyTestCase(TestCase): + """ + A test case to test the celery save_ssh_key task + """ + + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + ) + def setUp(self): + self.public_key = settings.TEST_MANAGE_SSH_KEY_PUBKEY + self.hosts = settings.TEST_MANAGE_SSH_KEY_HOST + + @skipIf(settings.TEST_MANAGE_SSH_KEY_PUBKEY is None or + settings.TEST_MANAGE_SSH_KEY_PUBKEY == "" or + settings.TEST_MANAGE_SSH_KEY_HOST is None or + settings.TEST_MANAGE_SSH_KEY_HOST is "", + """Skipping test_save_ssh_key_add because either host + or public key were not specified or were empty""") + def test_save_ssh_key_add(self): + async_task = save_ssh_key.delay([self.hosts], + [{'value': self.public_key, + 'state': True}]) + save_ssh_key_result = None + for i in range(0, 10): + sleep(5) + res = AsyncResult(async_task.task_id) + if type(res.result) is bool: + save_ssh_key_result = res.result + break + self.assertIsNotNone(save_ssh_key, "save_ssh_key_result is None") + self.assertTrue(save_ssh_key_result, "save_ssh_key_result is False") + + @skipIf(settings.TEST_MANAGE_SSH_KEY_PUBKEY is None or + settings.TEST_MANAGE_SSH_KEY_PUBKEY == "" or + settings.TEST_MANAGE_SSH_KEY_HOST is None or + settings.TEST_MANAGE_SSH_KEY_HOST is "", + """Skipping test_save_ssh_key_add because either host + or public key were not specified or were empty""") + def test_save_ssh_key_remove(self): + async_task = save_ssh_key.delay([self.hosts], + [{'value': self.public_key, + 'state': False}]) + save_ssh_key_result = None + for i in range(0, 10): + sleep(5) + res = AsyncResult(async_task.task_id) + if type(res.result) is bool: + save_ssh_key_result = res.result + break + self.assertIsNotNone(save_ssh_key, "save_ssh_key_result is None") + self.assertTrue(save_ssh_key_result, "save_ssh_key_result is False")