From ef064299416c6818adabcb22adb3cff205e547a5 Mon Sep 17 00:00:00 2001 From: mathieui Date: Sat, 3 Apr 2021 19:12:59 +0200 Subject: slixmpp.util: type things Fix a bug in the SASL implementation as well. (some special chars would make things crash instead of being escaped) --- slixmpp/util/sasl/client.py | 32 +++++++++++++------ slixmpp/util/sasl/mechanisms.py | 68 +++++++++++++++++++++-------------------- 2 files changed, 57 insertions(+), 43 deletions(-) (limited to 'slixmpp/util/sasl') diff --git a/slixmpp/util/sasl/client.py b/slixmpp/util/sasl/client.py index 7c9d38e0..7565db6b 100644 --- a/slixmpp/util/sasl/client.py +++ b/slixmpp/util/sasl/client.py @@ -1,4 +1,3 @@ - # slixmpp.util.sasl.client # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # This module was originally based on Dave Cridland's Suelta library. @@ -6,9 +5,11 @@ # :copryight: (c) 2004-2013 David Alan Cridland # :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout # :license: MIT, see LICENSE for more details +from __future__ import annotations import logging import stringprep +from typing import Iterable, Set, Callable, Dict, Any, Optional, Type from slixmpp.util import hashes, bytes, stringprep_profiles @@ -16,11 +17,11 @@ log = logging.getLogger(__name__) #: Global registry mapping mechanism names to implementation classes. -MECHANISMS = {} +MECHANISMS: Dict[str, Type[Mech]] = {} #: Global registry mapping mechanism names to security scores. -MECH_SEC_SCORES = {} +MECH_SEC_SCORES: Dict[str, int] = {} #: The SASLprep profile of stringprep used to validate simple username @@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create( unassigned=[stringprep.in_table_a1]) -def sasl_mech(score): +def sasl_mech(score: int): sec_score = score - def register(mech): + + def register(mech: Type[Mech]): n = 0 mech.score = sec_score if mech.use_hashes: @@ -99,9 +101,9 @@ class Mech(object): score = -1 use_hashes = False channel_binding = False - required_credentials = set() - optional_credentials = set() - security = set() + required_credentials: Set[str] = set() + optional_credentials: Set[str] = set() + security: Set[str] = set() def __init__(self, name, credentials, security_settings): self.credentials = credentials @@ -118,7 +120,14 @@ class Mech(object): return b'' -def choose(mech_list, credentials, security_settings, limit=None, min_mech=None): +CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]] +SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]] + + +def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback, + security_settings: SecurityCallback, + limit: Optional[Iterable[Type[Mech]]] = None, + min_mech: Optional[str] = None) -> Mech: available_mechs = set(MECHANISMS.keys()) if limit is None: limit = set(mech_list) @@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None) mech_list = mech_list.intersection(limit) available_mechs = available_mechs.intersection(mech_list) - best_score = MECH_SEC_SCORES.get(min_mech, -1) + if min_mech is None: + best_score = -1 + else: + best_score = MECH_SEC_SCORES.get(min_mech, -1) best_mech = None for name in available_mechs: if name in MECH_SEC_SCORES: diff --git a/slixmpp/util/sasl/mechanisms.py b/slixmpp/util/sasl/mechanisms.py index 53f39395..d53caec8 100644 --- a/slixmpp/util/sasl/mechanisms.py +++ b/slixmpp/util/sasl/mechanisms.py @@ -11,6 +11,9 @@ import hmac import random from base64 import b64encode, b64decode +from typing import List, Dict, Optional + +bytes_ = bytes from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes from slixmpp.util.sasl.client import sasl_mech, Mech, \ @@ -63,7 +66,7 @@ class PLAIN(Mech): if not self.security_settings['encrypted_plain']: raise SASLCancelled('PLAIN with encryption') - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> bytes_: authzid = self.credentials['authzid'] authcid = self.credentials['username'] password = self.credentials['password'] @@ -148,7 +151,7 @@ class CRAM(Mech): required_credentials = {'username', 'password'} security = {'encrypted', 'unencrypted_cram'} - def setup(self, name): + def setup(self, name: str): self.hash_name = name[5:] self.hash = hash(self.hash_name) if self.hash is None: @@ -157,14 +160,14 @@ class CRAM(Mech): if not self.security_settings['unencrypted_cram']: raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name) - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: return None username = self.credentials['username'] password = self.credentials['password'] - mac = hmac.HMAC(key=password, digestmod=self.hash) + mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore mac.update(challenge) return username + b' ' + bytes(mac.hexdigest()) @@ -201,43 +204,42 @@ class SCRAM(Mech): def HMAC(self, key, msg): return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest() - def Hi(self, text, salt, iterations): - text = bytes(text) - ui1 = self.HMAC(text, salt + b'\0\0\0\01') + def Hi(self, text: str, salt: bytes_, iterations: int): + text_enc = bytes(text) + ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01') ui = ui1 for i in range(iterations - 1): - ui1 = self.HMAC(text, ui1) + ui1 = self.HMAC(text_enc, ui1) ui = XOR(ui, ui1) return ui - def H(self, text): + def H(self, text: str) -> bytes_: return self.hash(text).digest() - def saslname(self, value): - value = value.decode("utf-8") - escaped = [] + def saslname(self, value_b: bytes_) -> bytes_: + value = value_b.decode("utf-8") + escaped: List[str] = [] for char in value: if char == ',': - escaped += b'=2C' + escaped.append('=2C') elif char == '=': - escaped += b'=3D' + escaped.append('=3D') else: - escaped += char + escaped.append(char) return "".join(escaped).encode("utf-8") - def parse(self, challenge): + def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]: items = {} for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]: items[key] = value return items - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b''): steps = [self.process_1, self.process_2, self.process_3] return steps[self.step](challenge) - def process_1(self, challenge): + def process_1(self, challenge: bytes_) -> bytes_: self.step = 1 - data = {} self.cnonce = bytes(('%s' % random.random())[2:]) @@ -263,7 +265,7 @@ class SCRAM(Mech): return self.client_first_message - def process_2(self, challenge): + def process_2(self, challenge: bytes_) -> bytes_: self.step = 2 data = self.parse(challenge) @@ -304,7 +306,7 @@ class SCRAM(Mech): return client_final_message - def process_3(self, challenge): + def process_3(self, challenge: bytes_) -> bytes_: data = self.parse(challenge) verifier = data.get(b'v', None) error = data.get(b'e', 'Unknown error') @@ -345,17 +347,16 @@ class DIGEST(Mech): self.cnonce = b'' self.nonce_count = 1 - def parse(self, challenge=b''): - data = {} + def parse(self, challenge: bytes_ = b''): + data: Dict[str, bytes_] = {} var_name = b'' var_value = b'' # States: var, new_var, end, quote, escaped_quote state = 'var' - - for char in challenge: - char = bytes([char]) + for char_int in challenge: + char = bytes_([char_int]) if state == 'var': if char.isspace(): @@ -401,14 +402,14 @@ class DIGEST(Mech): state = 'var' return data - def MAC(self, key, seq, msg): + def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_: mac = hmac.HMAC(key=key, digestmod=self.hash) seqnum = num_to_bytes(seq) mac.update(seqnum) mac.update(msg) return mac.digest()[:10] + b'\x00\x01' + seqnum - def A1(self): + def A1(self) -> bytes_: username = self.credentials['username'] password = self.credentials['password'] authzid = self.credentials['authzid'] @@ -423,13 +424,13 @@ class DIGEST(Mech): return bytes(a1) - def A2(self, prefix=b''): + def A2(self, prefix: bytes_ = b'') -> bytes_: a2 = prefix + b':' + self.digest_uri() if self.qop in (b'auth-int', b'auth-conf'): a2 += b':00000000000000000000000000000000' return bytes(a2) - def response(self, prefix=b''): + def response(self, prefix: bytes_ = b'') -> bytes_: nc = bytes('%08x' % self.nonce_count) a1 = bytes(self.hash(self.A1()).hexdigest().lower()) @@ -439,7 +440,7 @@ class DIGEST(Mech): return bytes(self.hash(a1 + b':' + s).hexdigest().lower()) - def digest_uri(self): + def digest_uri(self) -> bytes_: serv_type = self.credentials['service'] serv_name = self.credentials['service-name'] host = self.credentials['host'] @@ -449,7 +450,7 @@ class DIGEST(Mech): uri += b'/' + serv_name return uri - def respond(self): + def respond(self) -> bytes_: data = { 'username': quote(self.credentials['username']), 'authzid': quote(self.credentials['authzid']), @@ -469,7 +470,7 @@ class DIGEST(Mech): resp += b',' + bytes(key) + b'=' + bytes(value) return resp[1:] - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: if self.cnonce and self.nonce and self.nonce_count and self.qop: self.nonce_count += 1 @@ -480,6 +481,7 @@ class DIGEST(Mech): if 'rspauth' in data: if data['rspauth'] != self.response(): raise SASLMutualAuthFailed() + return None else: self.nonce_count = 1 self.cnonce = bytes('%s' % random.random())[2:] -- cgit v1.2.3