diff options
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 101 |
1 files changed, 80 insertions, 21 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 30f99071..19c4ddcc 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -15,6 +15,7 @@ from typing import ( Coroutine, Callable, Iterator, + Iterable, List, Optional, Set, @@ -31,8 +32,10 @@ import functools import logging import socket as Socket import ssl -import weakref import uuid +import warnings +import weakref +import collections from contextlib import contextmanager import xml.etree.ElementTree as ET @@ -46,6 +49,7 @@ from asyncio import ( iscoroutinefunction, wait, ) +from pathlib import Path from slixmpp.types import FilterString from slixmpp.xmlstream.tostring import tostring @@ -74,6 +78,15 @@ class NotConnectedError(Exception): """ +class InvalidCABundle(Exception): + """ + Exception raised when the CA Bundle file hasn't been found. + """ + + def __init__(self, path: Optional[Union[Path, Iterable[Path]]]): + self.path = path + + _T = TypeVar('_T', str, ElementBase, StanzaBase) @@ -161,7 +174,7 @@ class XMLStream(asyncio.BaseProtocol): #: #: On Mac OS X, certificates in the system keyring will #: be consulted, even if they are not in the provided file. - ca_certs: Optional[str] + ca_certs: Optional[Union[Path, Iterable[Path]]] #: Path to a file containing a client certificate to use for #: authenticating via SASL EXTERNAL. If set, there must also @@ -449,7 +462,7 @@ class XMLStream(asyncio.BaseProtocol): if self._connect_loop_wait > 0: self.event('reconnect_delay', self._connect_loop_wait) - await asyncio.sleep(self._connect_loop_wait, loop=self.loop) + await asyncio.sleep(self._connect_loop_wait) record = await self._pick_dns_answer(self.default_domain) if record is not None: @@ -480,16 +493,11 @@ class XMLStream(asyncio.BaseProtocol): except Socket.gaierror as e: self.event('connection_failed', 'No DNS record available for %s' % self.default_domain) + self.reschedule_connection_attempt() except OSError as e: log.debug('Connection failed: %s', e) self.event("connection_failed", e) - if self._current_connection_attempt is None: - return - self._connect_loop_wait = self._connect_loop_wait * 2 + 1 - self._current_connection_attempt = asyncio.ensure_future( - self._connect_routine(), - loop=self.loop, - ) + self.reschedule_connection_attempt() def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: """Process all the available XMPP events (receiving or sending data on the @@ -497,17 +505,27 @@ class XMLStream(asyncio.BaseProtocol): timers, handling signal events, etc). If timeout is None, this function will run forever. If timeout is a number, this function will return after the given time in seconds. + + Will be removed in slixmpp 1.9.0 + + :deprecated: 1.8.0 """ + warnings.warn( + 'This function will be removed in slixmpp 1.9 and above.' + ' Use the asyncio normal functions instead.', + category=DeprecationWarning, + stacklevel=2, + ) if timeout is None: if forever: self.loop.run_forever() else: self.loop.run_until_complete(self.disconnected) else: - tasks: List[Future] = [asyncio.sleep(timeout, loop=self.loop)] + tasks: List[Awaitable] = [asyncio.sleep(timeout)] if not forever: tasks.append(self.disconnected) - self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) + self.loop.run_until_complete(asyncio.wait(tasks)) def init_parser(self) -> None: """init the XML parser. The parser must always be reset for each new @@ -556,7 +574,7 @@ class XMLStream(asyncio.BaseProtocol): stream=self, top_level=True, open_only=True)) - self.start_stream_handler(self.xml_root) + self.start_stream_handler(self.xml_root) # type:ignore self.xml_depth += 1 if event == 'end': self.xml_depth -= 1 @@ -615,6 +633,20 @@ class XMLStream(asyncio.BaseProtocol): self._set_disconnected_future() self.event("disconnected", self.disconnect_reason or exception) + def reschedule_connection_attempt(self) -> None: + """ + Increase the exponential back-off and initate another background + _connect_routine call to connect to the server. + """ + # abort if there is no ongoing connection attempt + if self._current_connection_attempt is None: + return + self._connect_loop_wait = min(300, self._connect_loop_wait * 2 + 1) + self._current_connection_attempt = asyncio.ensure_future( + self._connect_routine(), + loop=self.loop, + ) + def cancel_connection_attempt(self) -> None: """ Immediately cancel the current create_connection() Future. @@ -715,7 +747,7 @@ class XMLStream(asyncio.BaseProtocol): log.debug("reconnecting...") async def handler(event: Any) -> None: # We yield here to allow synchronous handlers to work first - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0) self.connect() self.add_event_handler('disconnected', handler, disposable=True) self.disconnect(wait, reason) @@ -759,8 +791,26 @@ class XMLStream(asyncio.BaseProtocol): log.debug('Loaded cert file %s and key file %s', self.certfile, self.keyfile) if self.ca_certs is not None: + ca_cert: Optional[Path] = None + # XXX: Compat before d733c54518. + if isinstance(self.ca_certs, str): + self.ca_certs = Path(self.ca_certs) + if isinstance(self.ca_certs, Path): + if self.ca_certs.is_file(): + ca_cert = self.ca_certs + else: + for bundle in self.ca_certs: + if bundle.is_file(): + ca_cert = bundle + break + if ca_cert is None and \ + isinstance(self.ca_certs, (Path, collections.abc.Iterable)): + raise InvalidCABundle(self.ca_certs) + self.ssl_context.verify_mode = ssl.CERT_REQUIRED - self.ssl_context.load_verify_locations(cafile=self.ca_certs) + self.ssl_context.load_verify_locations(cafile=ca_cert) + else: + self.ssl_context.set_default_verify_paths() return self.ssl_context @@ -1202,7 +1252,7 @@ class XMLStream(asyncio.BaseProtocol): else: self.send_raw(data) - async def run_filters(self) -> NoReturn: + async def run_filters(self) -> None: """ Background loop that processes stanzas to send. """ @@ -1217,7 +1267,7 @@ class XMLStream(asyncio.BaseProtocol): already_run_filters.add(filter) if iscoroutinefunction(filter): filter = cast(AsyncFilter, filter) - task = asyncio.create_task(filter(data)) + task = asyncio.create_task(filter(data)) # type:ignore completed, pending = await wait( {task}, timeout=1, @@ -1258,6 +1308,9 @@ class XMLStream(asyncio.BaseProtocol): self.send_raw(data) except ContinueQueue as exc: log.debug('Stanza in send queue not sent: %s', exc) + except asyncio.CancelledError: + log.debug('Send coroutine received cancel(), stopping') + return except Exception: log.error('Exception raised in send queue:', exc_info=True) self.waiting_queue.task_done() @@ -1278,10 +1331,16 @@ class XMLStream(asyncio.BaseProtocol): # Avoid circular imports from slixmpp.stanza.rootstanza import RootStanza from slixmpp.stanza import Iq, Handshake - passthrough = ( - (isinstance(data, Iq) and data.get_plugin('bind', check=True)) - or isinstance(data, Handshake) - ) + + passthrough = False + if isinstance(data, Iq): + if data.get_plugin('bind', check=True): + passthrough = True + elif data.get_plugin('session', check=True): + passthrough = True + elif isinstance(data, Handshake): + passthrough = True + if isinstance(data, (RootStanza, str)) and not passthrough: self.__queued_stanzas.append((data, use_filters)) log.debug('NOT SENT: %s %s', type(data), data) |