diff --git a/hosting/forms.py b/hosting/forms.py index 464e2059..458bee6f 100644 --- a/hosting/forms.py +++ b/hosting/forms.py @@ -1,5 +1,8 @@ import datetime import logging +import subprocess +import tempfile +import os from django import forms from django.contrib.auth import authenticate @@ -92,25 +95,22 @@ class UserHostingKeyForm(forms.ModelForm): return self.data.get('public_key') KEY_ERROR_MESSAGE = _("Please input a proper SSH key") openssh_pubkey_str = self.data.get('public_key') - try: - ssh_key = SSHKey(openssh_pubkey_str) - ssh_key.parse() - except InvalidKeyException as err: - logger.error( - "InvalidKeyException while parsing ssh key {0}".format(err)) - raise forms.ValidationError(KEY_ERROR_MESSAGE) - except NotImplementedError as err: - logger.error( - "NotImplementedError while parsing ssh key {0}".format(err)) - raise forms.ValidationError(KEY_ERROR_MESSAGE) - except UnicodeDecodeError as u: - logger.error( - "UnicodeDecodeError while parsing ssh key {0}".format(u)) - raise forms.ValidationError(KEY_ERROR_MESSAGE) - except ValueError as v: - logger.error( - "ValueError while parsing ssh key {0}".format(v)) - raise forms.ValidationError(KEY_ERROR_MESSAGE) + + with tempfile.NamedTemporaryFile(delete=True) as tmp_public_key_file: + tmp_public_key_file.writelines(openssh_pubkey_str) + tmp_public_key_file.flush() + try: + out = subprocess.check_output( + ['ssh-keygen', '-lf', tmp_public_key_file.name]) + except subprocess.CalledProcessError as cpe: + logger.debug( + "Not a correct ssh format {error} {out}".format( + error=str(cpe), out=out)) + raise forms.ValidationError(KEY_ERROR_MESSAGE) + try: + os.remove(tmp_public_key_file.name) + except OSError: + pass return openssh_pubkey_str def clean_name(self): diff --git a/utils/tasks.py b/utils/tasks.py index 1844bc16..21d2c9b3 100644 --- a/utils/tasks.py +++ b/utils/tasks.py @@ -1,4 +1,5 @@ import tempfile +import os import cdist from cdist.integration import configure_hosts_simple @@ -67,6 +68,10 @@ def save_ssh_key(self, hosts, keys): except Exception as cdist_exception: logger.error(cdist_exception) return_value = False + try: + os.remove(tmp_manifest.name) + except OSError: + pass return return_value