summaryrefslogtreecommitdiff
path: root/slixmpp/util/sasl/mechanisms.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/util/sasl/mechanisms.py')
-rw-r--r--slixmpp/util/sasl/mechanisms.py68
1 files changed, 35 insertions, 33 deletions
diff --git a/slixmpp/util/sasl/mechanisms.py b/slixmpp/util/sasl/mechanisms.py
index 53f39395..d53caec8 100644
--- a/slixmpp/util/sasl/mechanisms.py
+++ b/slixmpp/util/sasl/mechanisms.py
@@ -11,6 +11,9 @@ import hmac
import random
from base64 import b64encode, b64decode
+from typing import List, Dict, Optional
+
+bytes_ = bytes
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
from slixmpp.util.sasl.client import sasl_mech, Mech, \
@@ -63,7 +66,7 @@ class PLAIN(Mech):
if not self.security_settings['encrypted_plain']:
raise SASLCancelled('PLAIN with encryption')
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b'') -> bytes_:
authzid = self.credentials['authzid']
authcid = self.credentials['username']
password = self.credentials['password']
@@ -148,7 +151,7 @@ class CRAM(Mech):
required_credentials = {'username', 'password'}
security = {'encrypted', 'unencrypted_cram'}
- def setup(self, name):
+ def setup(self, name: str):
self.hash_name = name[5:]
self.hash = hash(self.hash_name)
if self.hash is None:
@@ -157,14 +160,14 @@ class CRAM(Mech):
if not self.security_settings['unencrypted_cram']:
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
return None
username = self.credentials['username']
password = self.credentials['password']
- mac = hmac.HMAC(key=password, digestmod=self.hash)
+ mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
mac.update(challenge)
return username + b' ' + bytes(mac.hexdigest())
@@ -201,43 +204,42 @@ class SCRAM(Mech):
def HMAC(self, key, msg):
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
- def Hi(self, text, salt, iterations):
- text = bytes(text)
- ui1 = self.HMAC(text, salt + b'\0\0\0\01')
+ def Hi(self, text: str, salt: bytes_, iterations: int):
+ text_enc = bytes(text)
+ ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
ui = ui1
for i in range(iterations - 1):
- ui1 = self.HMAC(text, ui1)
+ ui1 = self.HMAC(text_enc, ui1)
ui = XOR(ui, ui1)
return ui
- def H(self, text):
+ def H(self, text: str) -> bytes_:
return self.hash(text).digest()
- def saslname(self, value):
- value = value.decode("utf-8")
- escaped = []
+ def saslname(self, value_b: bytes_) -> bytes_:
+ value = value_b.decode("utf-8")
+ escaped: List[str] = []
for char in value:
if char == ',':
- escaped += b'=2C'
+ escaped.append('=2C')
elif char == '=':
- escaped += b'=3D'
+ escaped.append('=3D')
else:
- escaped += char
+ escaped.append(char)
return "".join(escaped).encode("utf-8")
- def parse(self, challenge):
+ def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
items = {}
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
items[key] = value
return items
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b''):
steps = [self.process_1, self.process_2, self.process_3]
return steps[self.step](challenge)
- def process_1(self, challenge):
+ def process_1(self, challenge: bytes_) -> bytes_:
self.step = 1
- data = {}
self.cnonce = bytes(('%s' % random.random())[2:])
@@ -263,7 +265,7 @@ class SCRAM(Mech):
return self.client_first_message
- def process_2(self, challenge):
+ def process_2(self, challenge: bytes_) -> bytes_:
self.step = 2
data = self.parse(challenge)
@@ -304,7 +306,7 @@ class SCRAM(Mech):
return client_final_message
- def process_3(self, challenge):
+ def process_3(self, challenge: bytes_) -> bytes_:
data = self.parse(challenge)
verifier = data.get(b'v', None)
error = data.get(b'e', 'Unknown error')
@@ -345,17 +347,16 @@ class DIGEST(Mech):
self.cnonce = b''
self.nonce_count = 1
- def parse(self, challenge=b''):
- data = {}
+ def parse(self, challenge: bytes_ = b''):
+ data: Dict[str, bytes_] = {}
var_name = b''
var_value = b''
# States: var, new_var, end, quote, escaped_quote
state = 'var'
-
- for char in challenge:
- char = bytes([char])
+ for char_int in challenge:
+ char = bytes_([char_int])
if state == 'var':
if char.isspace():
@@ -401,14 +402,14 @@ class DIGEST(Mech):
state = 'var'
return data
- def MAC(self, key, seq, msg):
+ def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
mac = hmac.HMAC(key=key, digestmod=self.hash)
seqnum = num_to_bytes(seq)
mac.update(seqnum)
mac.update(msg)
return mac.digest()[:10] + b'\x00\x01' + seqnum
- def A1(self):
+ def A1(self) -> bytes_:
username = self.credentials['username']
password = self.credentials['password']
authzid = self.credentials['authzid']
@@ -423,13 +424,13 @@ class DIGEST(Mech):
return bytes(a1)
- def A2(self, prefix=b''):
+ def A2(self, prefix: bytes_ = b'') -> bytes_:
a2 = prefix + b':' + self.digest_uri()
if self.qop in (b'auth-int', b'auth-conf'):
a2 += b':00000000000000000000000000000000'
return bytes(a2)
- def response(self, prefix=b''):
+ def response(self, prefix: bytes_ = b'') -> bytes_:
nc = bytes('%08x' % self.nonce_count)
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
@@ -439,7 +440,7 @@ class DIGEST(Mech):
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
- def digest_uri(self):
+ def digest_uri(self) -> bytes_:
serv_type = self.credentials['service']
serv_name = self.credentials['service-name']
host = self.credentials['host']
@@ -449,7 +450,7 @@ class DIGEST(Mech):
uri += b'/' + serv_name
return uri
- def respond(self):
+ def respond(self) -> bytes_:
data = {
'username': quote(self.credentials['username']),
'authzid': quote(self.credentials['authzid']),
@@ -469,7 +470,7 @@ class DIGEST(Mech):
resp += b',' + bytes(key) + b'=' + bytes(value)
return resp[1:]
- def process(self, challenge=b''):
+ def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
if self.cnonce and self.nonce and self.nonce_count and self.qop:
self.nonce_count += 1
@@ -480,6 +481,7 @@ class DIGEST(Mech):
if 'rspauth' in data:
if data['rspauth'] != self.response():
raise SASLMutualAuthFailed()
+ return None
else:
self.nonce_count = 1
self.cnonce = bytes('%s' % random.random())[2:]