summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieui <mathieui@mathieui.net>2021-04-03 19:12:59 +0200
committermathieui <mathieui@mathieui.net>2021-06-28 21:00:30 +0200
commitef064299416c6818adabcb22adb3cff205e547a5 (patch)
tree1b8f5edd6a49dcc529c5293a9993c6ea03dd5296
parentb1411d8ed79792c6839f4aace13061256337e69b (diff)
downloadslixmpp-ef064299416c6818adabcb22adb3cff205e547a5.tar.gz
slixmpp-ef064299416c6818adabcb22adb3cff205e547a5.tar.bz2
slixmpp-ef064299416c6818adabcb22adb3cff205e547a5.tar.xz
slixmpp-ef064299416c6818adabcb22adb3cff205e547a5.zip
slixmpp.util: type things
Fix a bug in the SASL implementation as well. (some special chars would make things crash instead of being escaped)
-rw-r--r--slixmpp/util/cache.py24
-rw-r--r--slixmpp/util/misc_ops.py53
-rw-r--r--slixmpp/util/sasl/client.py32
-rw-r--r--slixmpp/util/sasl/mechanisms.py68
4 files changed, 89 insertions, 88 deletions
diff --git a/slixmpp/util/cache.py b/slixmpp/util/cache.py
index 23592404..b7042a56 100644
--- a/slixmpp/util/cache.py
+++ b/slixmpp/util/cache.py
@@ -1,4 +1,3 @@
-
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2018 Emmanuel Gil Peyrot
# This file is part of Slixmpp.
@@ -6,8 +5,11 @@
import os
import logging
+from typing import Callable, Optional, Any
+
log = logging.getLogger(__name__)
+
class Cache:
def retrieve(self, key):
raise NotImplementedError
@@ -16,7 +18,8 @@ class Cache:
raise NotImplementedError
def remove(self, key):
- raise NotImplemented
+ raise NotImplementedError
+
class PerJidCache:
def retrieve_by_jid(self, jid, key):
@@ -28,6 +31,7 @@ class PerJidCache:
def remove_by_jid(self, jid, key):
raise NotImplementedError
+
class MemoryCache(Cache):
def __init__(self):
self.cache = {}
@@ -44,6 +48,7 @@ class MemoryCache(Cache):
del self.cache[key]
return True
+
class MemoryPerJidCache(PerJidCache):
def __init__(self):
self.cache = {}
@@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache):
del cache[key]
return True
+
class FileSystemStorage:
- def __init__(self, encode, decode, binary):
+ def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool):
self.encode = encode if encode is not None else lambda x: x
self.decode = decode if decode is not None else lambda x: x
self.read = 'rb' if binary else 'r'
self.write = 'wb' if binary else 'w'
- def _retrieve(self, directory, key):
+ def _retrieve(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_'))
try:
with open(filename, self.read) as cache_file:
@@ -86,7 +92,7 @@ class FileSystemStorage:
log.debug('Removing %s entry', key)
self._remove(directory, key)
- def _store(self, directory, key, value):
+ def _store(self, directory: str, key: str, value):
filename = os.path.join(directory, key.replace('/', '_'))
try:
os.makedirs(directory, exist_ok=True)
@@ -99,7 +105,7 @@ class FileSystemStorage:
except Exception:
log.debug('Failed to encode %s to cache:', key, exc_info=True)
- def _remove(self, directory, key):
+ def _remove(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_'))
try:
os.remove(filename)
@@ -108,8 +114,9 @@ class FileSystemStorage:
return False
return True
+
class FileSystemCache(Cache, FileSystemStorage):
- def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
+ def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type)
@@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage):
def remove(self, key):
return self._remove(self.base_dir, key)
+
class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
- def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
+ def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type)
diff --git a/slixmpp/util/misc_ops.py b/slixmpp/util/misc_ops.py
index 1dcd6e3f..ed16d347 100644
--- a/slixmpp/util/misc_ops.py
+++ b/slixmpp/util/misc_ops.py
@@ -2,15 +2,19 @@ import builtins
import sys
import hashlib
+from typing import Optional, Union, Callable, List
-def unicode(text):
+bytes_ = builtins.bytes # alias the stdlib type but ew
+
+
+def unicode(text: Union[bytes_, str]) -> str:
if not isinstance(text, str):
return text.decode('utf-8')
else:
return text
-def bytes(text):
+def bytes(text: Optional[Union[str, bytes_]]) -> bytes_:
"""
Convert Unicode text to UTF-8 encoded bytes.
@@ -34,7 +38,7 @@ def bytes(text):
return builtins.bytes(text, encoding='utf-8')
-def quote(text):
+def quote(text: Union[str, bytes_]) -> bytes_:
"""
Enclose in quotes and escape internal slashes and double quotes.
@@ -44,7 +48,7 @@ def quote(text):
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
-def num_to_bytes(num):
+def num_to_bytes(num: int) -> bytes_:
"""
Convert an integer into a four byte sequence.
@@ -58,21 +62,21 @@ def num_to_bytes(num):
return bval
-def bytes_to_num(bval):
+def bytes_to_num(bval: bytes_) -> int:
"""
Convert a four byte sequence to an integer.
:param bytes bval: A four byte sequence to turn into an integer.
"""
num = 0
- num += ord(bval[0] << 24)
- num += ord(bval[1] << 16)
- num += ord(bval[2] << 8)
- num += ord(bval[3])
+ num += (bval[0] << 24)
+ num += (bval[1] << 16)
+ num += (bval[2] << 8)
+ num += (bval[3])
return num
-def XOR(x, y):
+def XOR(x: bytes_, y: bytes_) -> bytes_:
"""
Return the results of an XOR operation on two equal length byte strings.
@@ -85,7 +89,7 @@ def XOR(x, y):
return builtins.bytes([a ^ b for a, b in zip(x, y)])
-def hash(name):
+def hash(name: str) -> Optional[Callable]:
"""
Return a hash function implementing the given algorithm.
@@ -102,7 +106,7 @@ def hash(name):
return None
-def hashes():
+def hashes() -> List[str]:
"""
Return a list of available hashing algorithms.
@@ -115,28 +119,3 @@ def hashes():
t += ['MD2']
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
return t + hashes
-
-
-def setdefaultencoding(encoding):
- """
- Set the current default string encoding used by the Unicode implementation.
-
- Actually calls sys.setdefaultencoding under the hood - see the docs for that
- for more details. This method exists only as a way to call find/call it
- even after it has been 'deleted' when the site module is executed.
-
- :param string encoding: An encoding name, compatible with sys.setdefaultencoding
- """
- func = getattr(sys, 'setdefaultencoding', None)
- if func is None:
- import gc
- import types
- for obj in gc.get_objects():
- if (isinstance(obj, types.BuiltinFunctionType)
- and obj.__name__ == 'setdefaultencoding'):
- func = obj
- break
- if func is None:
- raise RuntimeError("Could not find setdefaultencoding")
- sys.setdefaultencoding = func
- return func(encoding)
diff --git a/slixmpp/util/sasl/client.py b/slixmpp/util/sasl/client.py
index 7c9d38e0..7565db6b 100644
--- a/slixmpp/util/sasl/client.py
+++ b/slixmpp/util/sasl/client.py
@@ -1,4 +1,3 @@
-
# slixmpp.util.sasl.client
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module was originally based on Dave Cridland's Suelta library.
@@ -6,9 +5,11 @@
# :copryight: (c) 2004-2013 David Alan Cridland
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
import logging
import stringprep
+from typing import Iterable, Set, Callable, Dict, Any, Optional, Type
from slixmpp.util import hashes, bytes, stringprep_profiles
@@ -16,11 +17,11 @@ log = logging.getLogger(__name__)
#: Global registry mapping mechanism names to implementation classes.
-MECHANISMS = {}
+MECHANISMS: Dict[str, Type[Mech]] = {}
#: Global registry mapping mechanism names to security scores.
-MECH_SEC_SCORES = {}
+MECH_SEC_SCORES: Dict[str, int] = {}
#: The SASLprep profile of stringprep used to validate simple username
@@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create(
unassigned=[stringprep.in_table_a1])
-def sasl_mech(score):
+def sasl_mech(score: int):
sec_score = score
- def register(mech):
+
+ def register(mech: Type[Mech]):
n = 0
mech.score = sec_score
if mech.use_hashes:
@@ -99,9 +101,9 @@ class Mech(object):
score = -1
use_hashes = False
channel_binding = False
- required_credentials = set()
- optional_credentials = set()
- security = set()
+ required_credentials: Set[str] = set()
+ optional_credentials: Set[str] = set()
+ security: Set[str] = set()
def __init__(self, name, credentials, security_settings):
self.credentials = credentials
@@ -118,7 +120,14 @@ class Mech(object):
return b''
-def choose(mech_list, credentials, security_settings, limit=None, min_mech=None):
+CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]]
+SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]]
+
+
+def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback,
+ security_settings: SecurityCallback,
+ limit: Optional[Iterable[Type[Mech]]] = None,
+ min_mech: Optional[str] = None) -> Mech:
available_mechs = set(MECHANISMS.keys())
if limit is None:
limit = set(mech_list)
@@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None)
mech_list = mech_list.intersection(limit)
available_mechs = available_mechs.intersection(mech_list)
- best_score = MECH_SEC_SCORES.get(min_mech, -1)
+ if min_mech is None:
+ best_score = -1
+ else:
+ best_score = MECH_SEC_SCORES.get(min_mech, -1)
best_mech = None
for name in available_mechs:
if name in MECH_SEC_SCORES:
diff --git a/slixmpp/util/sasl/mechanisms.py b/slixmpp/util/sasl/mechanisms.py
index 53f39395..d53caec8 100644
--- a/slixmpp/util/sasl/mechanisms.py
+++ b/slixmpp/util/sasl/mechanisms.py
@@ -11,6 +11,9 @@ import hmac
import random
from base64 import b64encode, b64decode
+from typing import List, Dict, Optional
+
+bytes_ = bytes
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
from slixmpp.util.sasl.client import sasl_mech, Mech, \
@@ -63,7 +66,7 @@ class PLAIN(Mech):
if not self.security_settings['encrypted_plain']:
raise SASLCancelled('PLAIN with encryption')
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b'') -> bytes_:
authzid = self.credentials['authzid']
authcid = self.credentials['username']
password = self.credentials['password']
@@ -148,7 +151,7 @@ class CRAM(Mech):
required_credentials = {'username', 'password'}
security = {'encrypted', 'unencrypted_cram'}
- def setup(self, name):
+ def setup(self, name: str):
self.hash_name = name[5:]
self.hash = hash(self.hash_name)
if self.hash is None:
@@ -157,14 +160,14 @@ class CRAM(Mech):
if not self.security_settings['unencrypted_cram']:
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
return None
username = self.credentials['username']
password = self.credentials['password']
- mac = hmac.HMAC(key=password, digestmod=self.hash)
+ mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
mac.update(challenge)
return username + b' ' + bytes(mac.hexdigest())
@@ -201,43 +204,42 @@ class SCRAM(Mech):
def HMAC(self, key, msg):
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
- def Hi(self, text, salt, iterations):
- text = bytes(text)
- ui1 = self.HMAC(text, salt + b'\0\0\0\01')
+ def Hi(self, text: str, salt: bytes_, iterations: int):
+ text_enc = bytes(text)
+ ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
ui = ui1
for i in range(iterations - 1):
- ui1 = self.HMAC(text, ui1)
+ ui1 = self.HMAC(text_enc, ui1)
ui = XOR(ui, ui1)
return ui
- def H(self, text):
+ def H(self, text: str) -> bytes_:
return self.hash(text).digest()
- def saslname(self, value):
- value = value.decode("utf-8")
- escaped = []
+ def saslname(self, value_b: bytes_) -> bytes_:
+ value = value_b.decode("utf-8")
+ escaped: List[str] = []
for char in value:
if char == ',':
- escaped += b'=2C'
+ escaped.append('=2C')
elif char == '=':
- escaped += b'=3D'
+ escaped.append('=3D')
else:
- escaped += char
+ escaped.append(char)
return "".join(escaped).encode("utf-8")
- def parse(self, challenge):
+ def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
items = {}
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
items[key] = value
return items
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b''):
steps = [self.process_1, self.process_2, self.process_3]
return steps[self.step](challenge)
- def process_1(self, challenge):
+ def process_1(self, challenge: bytes_) -> bytes_:
self.step = 1
- data = {}
self.cnonce = bytes(('%s' % random.random())[2:])
@@ -263,7 +265,7 @@ class SCRAM(Mech):
return self.client_first_message
- def process_2(self, challenge):
+ def process_2(self, challenge: bytes_) -> bytes_:
self.step = 2
data = self.parse(challenge)
@@ -304,7 +306,7 @@ class SCRAM(Mech):
return client_final_message
- def process_3(self, challenge):
+ def process_3(self, challenge: bytes_) -> bytes_:
data = self.parse(challenge)
verifier = data.get(b'v', None)
error = data.get(b'e', 'Unknown error')
@@ -345,17 +347,16 @@ class DIGEST(Mech):
self.cnonce = b''
self.nonce_count = 1
- def parse(self, challenge=b''):
- data = {}
+ def parse(self, challenge: bytes_ = b''):
+ data: Dict[str, bytes_] = {}
var_name = b''
var_value = b''
# States: var, new_var, end, quote, escaped_quote
state = 'var'
-
- for char in challenge:
- char = bytes([char])
+ for char_int in challenge:
+ char = bytes_([char_int])
if state == 'var':
if char.isspace():
@@ -401,14 +402,14 @@ class DIGEST(Mech):
state = 'var'
return data
- def MAC(self, key, seq, msg):
+ def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
mac = hmac.HMAC(key=key, digestmod=self.hash)
seqnum = num_to_bytes(seq)
mac.update(seqnum)
mac.update(msg)
return mac.digest()[:10] + b'\x00\x01' + seqnum
- def A1(self):
+ def A1(self) -> bytes_:
username = self.credentials['username']
password = self.credentials['password']
authzid = self.credentials['authzid']
@@ -423,13 +424,13 @@ class DIGEST(Mech):
return bytes(a1)
- def A2(self, prefix=b''):
+ def A2(self, prefix: bytes_ = b'') -> bytes_:
a2 = prefix + b':' + self.digest_uri()
if self.qop in (b'auth-int', b'auth-conf'):
a2 += b':00000000000000000000000000000000'
return bytes(a2)
- def response(self, prefix=b''):
+ def response(self, prefix: bytes_ = b'') -> bytes_:
nc = bytes('%08x' % self.nonce_count)
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
@@ -439,7 +440,7 @@ class DIGEST(Mech):
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
- def digest_uri(self):
+ def digest_uri(self) -> bytes_:
serv_type = self.credentials['service']
serv_name = self.credentials['service-name']
host = self.credentials['host']
@@ -449,7 +450,7 @@ class DIGEST(Mech):
uri += b'/' + serv_name
return uri
- def respond(self):
+ def respond(self) -> bytes_:
data = {
'username': quote(self.credentials['username']),
'authzid': quote(self.credentials['authzid']),
@@ -469,7 +470,7 @@ class DIGEST(Mech):
resp += b',' + bytes(key) + b'=' + bytes(value)
return resp[1:]
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
if self.cnonce and self.nonce and self.nonce_count and self.qop:
self.nonce_count += 1
@@ -480,6 +481,7 @@ class DIGEST(Mech):
if 'rspauth' in data:
if data['rspauth'] != self.response():
raise SASLMutualAuthFailed()
+ return None
else:
self.nonce_count = 1
self.cnonce = bytes('%s' % random.random())[2:]