import base64
import hashlib
import hmac
import random
import secrets
import sys

from ldap3 import Server, Connection, ObjectDef, Reader, ALL
from ldap3.core import exceptions
from ldap3.utils.conv import escape_filter_chars

# Number of random salt bytes generated for each new {SSHA} password hash.
SALT_BYTES: int = 15


class LdapManager:
    """Helper around an ldap3 connection for reading ``inetOrgPerson`` entries
    and for creating/verifying ``{SSHA}`` password hashes."""

    def __init__(self, server, admin_dn, admin_password):
        """Bind to *server* as *admin_dn* and load the ``inetOrgPerson`` object definition.

        :param server: LDAP server address, passed to :class:`ldap3.Server`.
        :param admin_dn: DN used for the automatic bind.
        :param admin_password: password for *admin_dn*.

        Terminates the process via :func:`sys.exit` if creating or binding the
        connection raises an ``ldap3`` ``LDAPException``.
        """
        self.server = Server(server, get_info=ALL)
        try:
            # Connect through the configured Server object so get_info=ALL is
            # honoured; previously the raw address was passed and self.server
            # was never used.
            self.conn = Connection(self.server, admin_dn, admin_password, auto_bind=True)
        except exceptions.LDAPException as err:
            sys.exit(f'LDAP Error: {err}')

        self.person_obj_def = ObjectDef('inetOrgPerson', self.conn)

    def get(self, query=None, search_base='dc=ungleich,dc=ch'):
        """Search *search_base* for ``inetOrgPerson`` entries.

        :param query: optional query/filter handed to :class:`ldap3.Reader`
            unchanged; callers must escape any untrusted values themselves.
        :param search_base: DN under which to search.
        :return: the result of ``Reader.search()``.
        """
        kwargs = {
            'connection': self.conn,
            'object_def': self.person_obj_def,
            'base': search_base,
        }
        if query:
            kwargs['query'] = query
        r = Reader(**kwargs)
        return r.search()

    def is_password_valid(self, query_value, password, query_key='mail', **kwargs):
        """Find the entry whose *query_key* equals *query_value* and verify *password*.

        :param query_value: value to match (e.g. a mail address). It is escaped
            per RFC 4515 before being placed in the filter.
        :param password: candidate plaintext password.
        :param query_key: attribute name used in the equality filter.
        :param kwargs: forwarded to :meth:`get` (e.g. ``search_base``).
        :return: the first matching entry, if the password matches.
        :raises ValueError: if no entry matches the filter.
        :raises Exception: ``'Invalid Password'`` if no stored password value matches.
        """
        # Escape the user-supplied value so input such as '*' or ')(' cannot
        # widen or restructure the filter (LDAP filter injection).
        query = '({}={})'.format(query_key, escape_filter_chars(str(query_value)))
        entries = self.get(query=query, **kwargs)
        if not entries:
            raise ValueError('Such {}={} not found'.format(query_key, query_value))

        entry = entries[0]
        stored = entry.userPassword.value
        # NOTE(review): ldap3 appears to return a list here for multi-valued
        # attributes; accept a match against any of the stored values.
        candidates = stored if isinstance(stored, list) else [stored]
        if not any(self._check_password(candidate, password) for candidate in candidates):
            raise Exception('Invalid Password')
        return entry

    @staticmethod
    def _check_password(tagged_digest_salt, password):
        """Return whether *password* matches a stored ``{SSHA}`` value.

        :param tagged_digest_salt: stored value of the form
            ``{SSHA}`` + base64(SHA-1(password + salt) + salt), as ``bytes``
            or ``str``. The tag is compared case-insensitively.
        :param password: candidate password, ``str`` (hashed as UTF-8) or ``bytes``.
        :return: ``True`` on a match. ``False`` on a mismatch, or when the
            stored value is ``None``, not ``{SSHA}``-tagged, or not valid base64.
        """
        if tagged_digest_salt is None:
            return False
        if isinstance(tagged_digest_salt, str):
            tagged_digest_salt = tagged_digest_salt.encode('utf-8')
        if isinstance(password, str):
            password = password.encode('utf-8')

        tag = b'{SSHA}'
        if tagged_digest_salt[:len(tag)].upper() != tag:
            return False
        try:
            digest_salt = base64.decodebytes(tagged_digest_salt[len(tag):])
        except ValueError:  # binascii.Error (e.g. bad padding) subclasses ValueError
            return False

        # A SHA-1 digest is 20 bytes; everything after it is the salt.
        digest, salt = digest_salt[:20], digest_salt[20:]
        sha = hashlib.sha1(password)
        sha.update(salt)

        # Constant-time comparison so timing does not reveal matching prefixes.
        return hmac.compare_digest(digest, sha.digest())

    @staticmethod
    def ssha_password(password, salt=None):
        """
        Apply the SSHA password hashing scheme to the given *password*.
        *password* should be a :class:`bytes` object containing the utf-8
        encoded password; a :class:`str` is also accepted and encoded as utf-8.
        *salt* defaults to ``SALT_BYTES`` fresh cryptographically random
        bytes; pass explicit bytes only when reproducible output is needed
        (e.g. in tests).

        Return a :class:`bytes` object containing ``ascii``-compatible data
        which can be used as LDAP value, e.g. after armoring it once more using
        base64 or decoding it to unicode from ``ascii``.
        """
        if isinstance(password, str):
            password = password.encode('utf-8')
        if salt is None:
            salt = secrets.token_bytes(SALT_BYTES)

        sha1 = hashlib.sha1(password)
        sha1.update(salt)
        return b'{SSHA}' + base64.b64encode(sha1.digest() + salt)