From ef064299416c6818adabcb22adb3cff205e547a5 Mon Sep 17 00:00:00 2001 From: mathieui Date: Sat, 3 Apr 2021 19:12:59 +0200 Subject: slixmpp.util: type things Fix a bug in the SASL implementation as well. (some special chars would make things crash instead of being escaped) --- slixmpp/util/cache.py | 24 ++++++++++----- slixmpp/util/misc_ops.py | 53 ++++++++++---------------------- slixmpp/util/sasl/client.py | 32 +++++++++++++------ slixmpp/util/sasl/mechanisms.py | 68 +++++++++++++++++++++-------------------- 4 files changed, 89 insertions(+), 88 deletions(-) diff --git a/slixmpp/util/cache.py b/slixmpp/util/cache.py index 23592404..b7042a56 100644 --- a/slixmpp/util/cache.py +++ b/slixmpp/util/cache.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2018 Emmanuel Gil Peyrot # This file is part of Slixmpp. @@ -6,8 +5,11 @@ import os import logging +from typing import Callable, Optional, Any + log = logging.getLogger(__name__) + class Cache: def retrieve(self, key): raise NotImplementedError @@ -16,7 +18,8 @@ class Cache: raise NotImplementedError def remove(self, key): - raise NotImplemented + raise NotImplementedError + class PerJidCache: def retrieve_by_jid(self, jid, key): @@ -28,6 +31,7 @@ class PerJidCache: def remove_by_jid(self, jid, key): raise NotImplementedError + class MemoryCache(Cache): def __init__(self): self.cache = {} @@ -44,6 +48,7 @@ class MemoryCache(Cache): del self.cache[key] return True + class MemoryPerJidCache(PerJidCache): def __init__(self): self.cache = {} @@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache): del cache[key] return True + class FileSystemStorage: - def __init__(self, encode, decode, binary): + def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool): self.encode = encode if encode is not None else lambda x: x self.decode = decode if decode is not None else lambda x: x self.read = 'rb' if binary else 'r' self.write = 'wb' if binary else 'w' - def _retrieve(self, directory, key): + def _retrieve(self, directory: str, key: str): filename = os.path.join(directory, key.replace('/', '_')) try: with open(filename, self.read) as cache_file: @@ -86,7 +92,7 @@ class FileSystemStorage: log.debug('Removing %s entry', key) self._remove(directory, key) - def _store(self, directory, key, value): + def _store(self, directory: str, key: str, value): filename = os.path.join(directory, key.replace('/', '_')) try: os.makedirs(directory, exist_ok=True) @@ -99,7 +105,7 @@ class FileSystemStorage: except Exception: log.debug('Failed to encode %s to cache:', key, exc_info=True) - def _remove(self, directory, key): + def _remove(self, directory: str, key: str): filename = os.path.join(directory, key.replace('/', '_')) try: os.remove(filename) @@ -108,8 +114,9 @@ class FileSystemStorage: return False return True + class FileSystemCache(Cache, FileSystemStorage): - def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): + def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False): FileSystemStorage.__init__(self, encode, decode, binary) self.base_dir = os.path.join(directory, cache_type) @@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage): def remove(self, key): return self._remove(self.base_dir, key) + class FileSystemPerJidCache(PerJidCache, FileSystemStorage): - def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): + def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False): FileSystemStorage.__init__(self, encode, decode, binary) self.base_dir = os.path.join(directory, cache_type) diff --git a/slixmpp/util/misc_ops.py b/slixmpp/util/misc_ops.py index 1dcd6e3f..ed16d347 100644 --- a/slixmpp/util/misc_ops.py +++ b/slixmpp/util/misc_ops.py @@ -2,15 +2,19 @@ import builtins import sys import hashlib +from typing import Optional, Union, Callable, List -def unicode(text): +bytes_ = builtins.bytes # alias the stdlib type but ew + + +def unicode(text: Union[bytes_, str]) -> str: if not isinstance(text, str): return text.decode('utf-8') else: return text -def bytes(text): +def bytes(text: Optional[Union[str, bytes_]]) -> bytes_: """ Convert Unicode text to UTF-8 encoded bytes. @@ -34,7 +38,7 @@ def bytes(text): return builtins.bytes(text, encoding='utf-8') -def quote(text): +def quote(text: Union[str, bytes_]) -> bytes_: """ Enclose in quotes and escape internal slashes and double quotes. @@ -44,7 +48,7 @@ def quote(text): return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"' -def num_to_bytes(num): +def num_to_bytes(num: int) -> bytes_: """ Convert an integer into a four byte sequence. @@ -58,21 +62,21 @@ def num_to_bytes(num): return bval -def bytes_to_num(bval): +def bytes_to_num(bval: bytes_) -> int: """ Convert a four byte sequence to an integer. :param bytes bval: A four byte sequence to turn into an integer. """ num = 0 - num += ord(bval[0] << 24) - num += ord(bval[1] << 16) - num += ord(bval[2] << 8) - num += ord(bval[3]) + num += (bval[0] << 24) + num += (bval[1] << 16) + num += (bval[2] << 8) + num += (bval[3]) return num -def XOR(x, y): +def XOR(x: bytes_, y: bytes_) -> bytes_: """ Return the results of an XOR operation on two equal length byte strings. @@ -85,7 +89,7 @@ def XOR(x, y): return builtins.bytes([a ^ b for a, b in zip(x, y)]) -def hash(name): +def hash(name: str) -> Optional[Callable]: """ Return a hash function implementing the given algorithm. @@ -102,7 +106,7 @@ def hash(name): return None -def hashes(): +def hashes() -> List[str]: """ Return a list of available hashing algorithms. @@ -115,28 +119,3 @@ def hashes(): t += ['MD2'] hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')] return t + hashes - - -def setdefaultencoding(encoding): - """ - Set the current default string encoding used by the Unicode implementation. - - Actually calls sys.setdefaultencoding under the hood - see the docs for that - for more details. This method exists only as a way to call find/call it - even after it has been 'deleted' when the site module is executed. - - :param string encoding: An encoding name, compatible with sys.setdefaultencoding - """ - func = getattr(sys, 'setdefaultencoding', None) - if func is None: - import gc - import types - for obj in gc.get_objects(): - if (isinstance(obj, types.BuiltinFunctionType) - and obj.__name__ == 'setdefaultencoding'): - func = obj - break - if func is None: - raise RuntimeError("Could not find setdefaultencoding") - sys.setdefaultencoding = func - return func(encoding) diff --git a/slixmpp/util/sasl/client.py b/slixmpp/util/sasl/client.py index 7c9d38e0..7565db6b 100644 --- a/slixmpp/util/sasl/client.py +++ b/slixmpp/util/sasl/client.py @@ -1,4 +1,3 @@ - # slixmpp.util.sasl.client # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # This module was originally based on Dave Cridland's Suelta library. @@ -6,9 +5,11 @@ # :copryight: (c) 2004-2013 David Alan Cridland # :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout # :license: MIT, see LICENSE for more details +from __future__ import annotations import logging import stringprep +from typing import Iterable, Set, Callable, Dict, Any, Optional, Type from slixmpp.util import hashes, bytes, stringprep_profiles @@ -16,11 +17,11 @@ log = logging.getLogger(__name__) #: Global registry mapping mechanism names to implementation classes. -MECHANISMS = {} +MECHANISMS: Dict[str, Type[Mech]] = {} #: Global registry mapping mechanism names to security scores. -MECH_SEC_SCORES = {} +MECH_SEC_SCORES: Dict[str, int] = {} #: The SASLprep profile of stringprep used to validate simple username @@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create( unassigned=[stringprep.in_table_a1]) -def sasl_mech(score): +def sasl_mech(score: int): sec_score = score - def register(mech): + + def register(mech: Type[Mech]): n = 0 mech.score = sec_score if mech.use_hashes: @@ -99,9 +101,9 @@ class Mech(object): score = -1 use_hashes = False channel_binding = False - required_credentials = set() - optional_credentials = set() - security = set() + required_credentials: Set[str] = set() + optional_credentials: Set[str] = set() + security: Set[str] = set() def __init__(self, name, credentials, security_settings): self.credentials = credentials @@ -118,7 +120,14 @@ class Mech(object): return b'' -def choose(mech_list, credentials, security_settings, limit=None, min_mech=None): +CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]] +SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]] + + +def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback, + security_settings: SecurityCallback, + limit: Optional[Iterable[Type[Mech]]] = None, + min_mech: Optional[str] = None) -> Mech: available_mechs = set(MECHANISMS.keys()) if limit is None: limit = set(mech_list) @@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None) mech_list = mech_list.intersection(limit) available_mechs = available_mechs.intersection(mech_list) - best_score = MECH_SEC_SCORES.get(min_mech, -1) + if min_mech is None: + best_score = -1 + else: + best_score = MECH_SEC_SCORES.get(min_mech, -1) best_mech = None for name in available_mechs: if name in MECH_SEC_SCORES: diff --git a/slixmpp/util/sasl/mechanisms.py b/slixmpp/util/sasl/mechanisms.py index 53f39395..d53caec8 100644 --- a/slixmpp/util/sasl/mechanisms.py +++ b/slixmpp/util/sasl/mechanisms.py @@ -11,6 +11,9 @@ import hmac import random from base64 import b64encode, b64decode +from typing import List, Dict, Optional + +bytes_ = bytes from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes from slixmpp.util.sasl.client import sasl_mech, Mech, \ @@ -63,7 +66,7 @@ class PLAIN(Mech): if not self.security_settings['encrypted_plain']: raise SASLCancelled('PLAIN with encryption') - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> bytes_: authzid = self.credentials['authzid'] authcid = self.credentials['username'] password = self.credentials['password'] @@ -148,7 +151,7 @@ class CRAM(Mech): required_credentials = {'username', 'password'} security = {'encrypted', 'unencrypted_cram'} - def setup(self, name): + def setup(self, name: str): self.hash_name = name[5:] self.hash = hash(self.hash_name) if self.hash is None: @@ -157,14 +160,14 @@ class CRAM(Mech): if not self.security_settings['unencrypted_cram']: raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name) - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: return None username = self.credentials['username'] password = self.credentials['password'] - mac = hmac.HMAC(key=password, digestmod=self.hash) + mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore mac.update(challenge) return username + b' ' + bytes(mac.hexdigest()) @@ -201,43 +204,42 @@ class SCRAM(Mech): def HMAC(self, key, msg): return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest() - def Hi(self, text, salt, iterations): - text = bytes(text) - ui1 = self.HMAC(text, salt + b'\0\0\0\01') + def Hi(self, text: str, salt: bytes_, iterations: int): + text_enc = bytes(text) + ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01') ui = ui1 for i in range(iterations - 1): - ui1 = self.HMAC(text, ui1) + ui1 = self.HMAC(text_enc, ui1) ui = XOR(ui, ui1) return ui - def H(self, text): + def H(self, text: str) -> bytes_: return self.hash(text).digest() - def saslname(self, value): - value = value.decode("utf-8") - escaped = [] + def saslname(self, value_b: bytes_) -> bytes_: + value = value_b.decode("utf-8") + escaped: List[str] = [] for char in value: if char == ',': - escaped += b'=2C' + escaped.append('=2C') elif char == '=': - escaped += b'=3D' + escaped.append('=3D') else: - escaped += char + escaped.append(char) return "".join(escaped).encode("utf-8") - def parse(self, challenge): + def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]: items = {} for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]: items[key] = value return items - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b''): steps = [self.process_1, self.process_2, self.process_3] return steps[self.step](challenge) - def process_1(self, challenge): + def process_1(self, challenge: bytes_) -> bytes_: self.step = 1 - data = {} self.cnonce = bytes(('%s' % random.random())[2:]) @@ -263,7 +265,7 @@ class SCRAM(Mech): return self.client_first_message - def process_2(self, challenge): + def process_2(self, challenge: bytes_) -> bytes_: self.step = 2 data = self.parse(challenge) @@ -304,7 +306,7 @@ class SCRAM(Mech): return client_final_message - def process_3(self, challenge): + def process_3(self, challenge: bytes_) -> bytes_: data = self.parse(challenge) verifier = data.get(b'v', None) error = data.get(b'e', 'Unknown error') @@ -345,17 +347,16 @@ class DIGEST(Mech): self.cnonce = b'' self.nonce_count = 1 - def parse(self, challenge=b''): - data = {} + def parse(self, challenge: bytes_ = b''): + data: Dict[str, bytes_] = {} var_name = b'' var_value = b'' # States: var, new_var, end, quote, escaped_quote state = 'var' - - for char in challenge: - char = bytes([char]) + for char_int in challenge: + char = bytes_([char_int]) if state == 'var': if char.isspace(): @@ -401,14 +402,14 @@ class DIGEST(Mech): state = 'var' return data - def MAC(self, key, seq, msg): + def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_: mac = hmac.HMAC(key=key, digestmod=self.hash) seqnum = num_to_bytes(seq) mac.update(seqnum) mac.update(msg) return mac.digest()[:10] + b'\x00\x01' + seqnum - def A1(self): + def A1(self) -> bytes_: username = self.credentials['username'] password = self.credentials['password'] authzid = self.credentials['authzid'] @@ -423,13 +424,13 @@ class DIGEST(Mech): return bytes(a1) - def A2(self, prefix=b''): + def A2(self, prefix: bytes_ = b'') -> bytes_: a2 = prefix + b':' + self.digest_uri() if self.qop in (b'auth-int', b'auth-conf'): a2 += b':00000000000000000000000000000000' return bytes(a2) - def response(self, prefix=b''): + def response(self, prefix: bytes_ = b'') -> bytes_: nc = bytes('%08x' % self.nonce_count) a1 = bytes(self.hash(self.A1()).hexdigest().lower()) @@ -439,7 +440,7 @@ class DIGEST(Mech): return bytes(self.hash(a1 + b':' + s).hexdigest().lower()) - def digest_uri(self): + def digest_uri(self) -> bytes_: serv_type = self.credentials['service'] serv_name = self.credentials['service-name'] host = self.credentials['host'] @@ -449,7 +450,7 @@ class DIGEST(Mech): uri += b'/' + serv_name return uri - def respond(self): + def respond(self) -> bytes_: data = { 'username': quote(self.credentials['username']), 'authzid': quote(self.credentials['authzid']), @@ -469,7 +470,7 @@ class DIGEST(Mech): resp += b',' + bytes(key) + b'=' + bytes(value) return resp[1:] - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: if self.cnonce and self.nonce and self.nonce_count and self.qop: self.nonce_count += 1 @@ -480,6 +481,7 @@ class DIGEST(Mech): if 'rspauth' in data: if data['rspauth'] != self.response(): raise SASLMutualAuthFailed() + return None else: self.nonce_count = 1 self.cnonce = bytes('%s' % random.random())[2:] -- cgit v1.2.3