From d733c54518cda652ec3c753c2483d925b20eae57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20=E2=80=9Cpep=E2=80=9D=20Buquet?= Date: Tue, 28 Dec 2021 19:50:20 +0100 Subject: Allow Xmlstream.ca_certs to be an iterable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Maxime “pep” Buquet --- slixmpp/xmlstream/xmlstream.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 7c4283f2..fd0269da 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -15,6 +15,7 @@ from typing import ( Coroutine, Callable, Iterator, + Iterable, List, Optional, Set, @@ -33,7 +34,6 @@ import socket as Socket import ssl import weakref import uuid -from pathlib import Path from contextlib import contextmanager import xml.etree.ElementTree as ET @@ -47,6 +47,7 @@ from asyncio import ( iscoroutinefunction, wait, ) +from pathlib import Path from slixmpp.types import FilterString from slixmpp.xmlstream.tostring import tostring @@ -75,6 +76,15 @@ class NotConnectedError(Exception): """ +class InvalidCABundle(Exception): + """ + Exception raised when the CA Bundle file hasn't been found. + """ + + def __init__(self, path: Optional[Path]): + self.path = path + + _T = TypeVar('_T', str, ElementBase, StanzaBase) @@ -162,7 +172,7 @@ class XMLStream(asyncio.BaseProtocol): #: #: On Mac OS X, certificates in the system keyring will #: be consulted, even if they are not in the provided file. - ca_certs: Optional[Path] + ca_certs: Optional[Union[Path, Iterable[Path]]] #: Path to a file containing a client certificate to use for #: authenticating via SASL EXTERNAL. If set, there must also @@ -760,8 +770,20 @@ class XMLStream(asyncio.BaseProtocol): log.debug('Loaded cert file %s and key file %s', self.certfile, self.keyfile) if self.ca_certs is not None: + ca_cert: Optional[Path] = None + if isinstance(self.ca_certs, Path): + if self.ca_certs.is_file(): + ca_cert = self.ca_certs + else: + for bundle in self.ca_certs: + if bundle.is_file(): + ca_cert = bundle + break + if ca_cert is None: + raise InvalidCABundle(ca_cert) + self.ssl_context.verify_mode = ssl.CERT_REQUIRED - self.ssl_context.load_verify_locations(cafile=self.ca_certs) + self.ssl_context.load_verify_locations(cafile=ca_cert) return self.ssl_context -- cgit v1.2.3