From ef064299416c6818adabcb22adb3cff205e547a5 Mon Sep 17 00:00:00 2001 From: mathieui Date: Sat, 3 Apr 2021 19:12:59 +0200 Subject: slixmpp.util: type things Fix a bug in the SASL implementation as well. (some special chars would make things crash instead of being escaped) --- slixmpp/util/sasl/mechanisms.py | 68 +++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 33 deletions(-) (limited to 'slixmpp/util/sasl/mechanisms.py') diff --git a/slixmpp/util/sasl/mechanisms.py b/slixmpp/util/sasl/mechanisms.py index 53f39395..d53caec8 100644 --- a/slixmpp/util/sasl/mechanisms.py +++ b/slixmpp/util/sasl/mechanisms.py @@ -11,6 +11,9 @@ import hmac import random from base64 import b64encode, b64decode +from typing import List, Dict, Optional + +bytes_ = bytes from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes from slixmpp.util.sasl.client import sasl_mech, Mech, \ @@ -63,7 +66,7 @@ class PLAIN(Mech): if not self.security_settings['encrypted_plain']: raise SASLCancelled('PLAIN with encryption') - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> bytes_: authzid = self.credentials['authzid'] authcid = self.credentials['username'] password = self.credentials['password'] @@ -148,7 +151,7 @@ class CRAM(Mech): required_credentials = {'username', 'password'} security = {'encrypted', 'unencrypted_cram'} - def setup(self, name): + def setup(self, name: str): self.hash_name = name[5:] self.hash = hash(self.hash_name) if self.hash is None: @@ -157,14 +160,14 @@ class CRAM(Mech): if not self.security_settings['unencrypted_cram']: raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name) - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: return None username = self.credentials['username'] password = self.credentials['password'] - mac = hmac.HMAC(key=password, digestmod=self.hash) + mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore mac.update(challenge) return username + b' ' + bytes(mac.hexdigest()) @@ -201,43 +204,42 @@ class SCRAM(Mech): def HMAC(self, key, msg): return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest() - def Hi(self, text, salt, iterations): - text = bytes(text) - ui1 = self.HMAC(text, salt + b'\0\0\0\01') + def Hi(self, text: str, salt: bytes_, iterations: int): + text_enc = bytes(text) + ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01') ui = ui1 for i in range(iterations - 1): - ui1 = self.HMAC(text, ui1) + ui1 = self.HMAC(text_enc, ui1) ui = XOR(ui, ui1) return ui - def H(self, text): + def H(self, text: str) -> bytes_: return self.hash(text).digest() - def saslname(self, value): - value = value.decode("utf-8") - escaped = [] + def saslname(self, value_b: bytes_) -> bytes_: + value = value_b.decode("utf-8") + escaped: List[str] = [] for char in value: if char == ',': - escaped += b'=2C' + escaped.append('=2C') elif char == '=': - escaped += b'=3D' + escaped.append('=3D') else: - escaped += char + escaped.append(char) return "".join(escaped).encode("utf-8") - def parse(self, challenge): + def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]: items = {} for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]: items[key] = value return items - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b''): steps = [self.process_1, self.process_2, self.process_3] return steps[self.step](challenge) - def process_1(self, challenge): + def process_1(self, challenge: bytes_) -> bytes_: self.step = 1 - data = {} self.cnonce = bytes(('%s' % random.random())[2:]) @@ -263,7 +265,7 @@ class SCRAM(Mech): return self.client_first_message - def process_2(self, challenge): + def process_2(self, challenge: bytes_) -> bytes_: self.step = 2 data = self.parse(challenge) @@ -304,7 +306,7 @@ class SCRAM(Mech): return client_final_message - def process_3(self, challenge): + def process_3(self, challenge: bytes_) -> bytes_: data = self.parse(challenge) verifier = data.get(b'v', None) error = data.get(b'e', 'Unknown error') @@ -345,17 +347,16 @@ class DIGEST(Mech): self.cnonce = b'' self.nonce_count = 1 - def parse(self, challenge=b''): - data = {} + def parse(self, challenge: bytes_ = b''): + data: Dict[str, bytes_] = {} var_name = b'' var_value = b'' # States: var, new_var, end, quote, escaped_quote state = 'var' - - for char in challenge: - char = bytes([char]) + for char_int in challenge: + char = bytes_([char_int]) if state == 'var': if char.isspace(): @@ -401,14 +402,14 @@ class DIGEST(Mech): state = 'var' return data - def MAC(self, key, seq, msg): + def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_: mac = hmac.HMAC(key=key, digestmod=self.hash) seqnum = num_to_bytes(seq) mac.update(seqnum) mac.update(msg) return mac.digest()[:10] + b'\x00\x01' + seqnum - def A1(self): + def A1(self) -> bytes_: username = self.credentials['username'] password = self.credentials['password'] authzid = self.credentials['authzid'] @@ -423,13 +424,13 @@ class DIGEST(Mech): return bytes(a1) - def A2(self, prefix=b''): + def A2(self, prefix: bytes_ = b'') -> bytes_: a2 = prefix + b':' + self.digest_uri() if self.qop in (b'auth-int', b'auth-conf'): a2 += b':00000000000000000000000000000000' return bytes(a2) - def response(self, prefix=b''): + def response(self, prefix: bytes_ = b'') -> bytes_: nc = bytes('%08x' % self.nonce_count) a1 = bytes(self.hash(self.A1()).hexdigest().lower()) @@ -439,7 +440,7 @@ class DIGEST(Mech): return bytes(self.hash(a1 + b':' + s).hexdigest().lower()) - def digest_uri(self): + def digest_uri(self) -> bytes_: serv_type = self.credentials['service'] serv_name = self.credentials['service-name'] host = self.credentials['host'] @@ -449,7 +450,7 @@ class DIGEST(Mech): uri += b'/' + serv_name return uri - def respond(self): + def respond(self) -> bytes_: data = { 'username': quote(self.credentials['username']), 'authzid': quote(self.credentials['authzid']), @@ -469,7 +470,7 @@ class DIGEST(Mech): resp += b',' + bytes(key) + b'=' + bytes(value) return resp[1:] - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: if self.cnonce and self.nonce and self.nonce_count and self.qop: self.nonce_count += 1 @@ -480,6 +481,7 @@ class DIGEST(Mech): if 'rspauth' in data: if data['rspauth'] != self.response(): raise SASLMutualAuthFailed() + return None else: self.nonce_count = 1 self.cnonce = bytes('%s' % random.random())[2:] -- cgit v1.2.3