diff options
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 49 |
1 files changed, 34 insertions, 15 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 7b362203..18464ccd 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -35,6 +35,7 @@ import ssl import uuid import warnings import weakref +import collections from contextlib import contextmanager import xml.etree.ElementTree as ET @@ -82,7 +83,7 @@ class InvalidCABundle(Exception): Exception raised when the CA Bundle file hasn't been found. """ - def __init__(self, path: Optional[Path]): + def __init__(self, path: Optional[Union[Path, Iterable[Path]]]): self.path = path @@ -492,16 +493,11 @@ class XMLStream(asyncio.BaseProtocol): except Socket.gaierror as e: self.event('connection_failed', 'No DNS record available for %s' % self.default_domain) + self.reschedule_connection_attempt() except OSError as e: log.debug('Connection failed: %s', e) self.event("connection_failed", e) - if self._current_connection_attempt is None: - return - self._connect_loop_wait = self._connect_loop_wait * 2 + 1 - self._current_connection_attempt = asyncio.ensure_future( - self._connect_routine(), - loop=self.loop, - ) + self.reschedule_connection_attempt() def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: """Process all the available XMPP events (receiving or sending data on the @@ -526,7 +522,7 @@ class XMLStream(asyncio.BaseProtocol): else: self.loop.run_until_complete(self.disconnected) else: - tasks: List[Future] = [asyncio.sleep(timeout)] + tasks: List[Awaitable] = [asyncio.sleep(timeout)] if not forever: tasks.append(self.disconnected) self.loop.run_until_complete(asyncio.wait(tasks)) @@ -637,6 +633,20 @@ class XMLStream(asyncio.BaseProtocol): self._set_disconnected_future() self.event("disconnected", self.disconnect_reason or exception) + def reschedule_connection_attempt(self) -> None: + """ + Increase the exponential back-off and initate another background + _connect_routine call to connect to the server. + """ + # abort if there is no ongoing connection attempt + if self._current_connection_attempt is None: + return + self._connect_loop_wait = min(300, self._connect_loop_wait * 2 + 1) + self._current_connection_attempt = asyncio.ensure_future( + self._connect_routine(), + loop=self.loop, + ) + def cancel_connection_attempt(self) -> None: """ Immediately cancel the current create_connection() Future. @@ -793,11 +803,14 @@ class XMLStream(asyncio.BaseProtocol): if bundle.is_file(): ca_cert = bundle break - if ca_cert is None: - raise InvalidCABundle(ca_cert) + if ca_cert is None and \ + isinstance(self.ca_certs, (Path, collections.abc.Iterable)): + raise InvalidCABundle(self.ca_certs) self.ssl_context.verify_mode = ssl.CERT_REQUIRED self.ssl_context.load_verify_locations(cafile=ca_cert) + else: + self.ssl_context.set_default_verify_paths() return self.ssl_context @@ -1318,10 +1331,16 @@ class XMLStream(asyncio.BaseProtocol): # Avoid circular imports from slixmpp.stanza.rootstanza import RootStanza from slixmpp.stanza import Iq, Handshake - passthrough = ( - (isinstance(data, Iq) and data.get_plugin('bind', check=True)) - or isinstance(data, Handshake) - ) + + passthrough = False + if isinstance(data, Iq): + if data.get_plugin('bind', check=True): + passthrough = True + elif data.get_plugin('session', check=True): + passthrough = True + elif isinstance(data, Handshake): + passthrough = True + if isinstance(data, (RootStanza, str)) and not passthrough: self.__queued_stanzas.append((data, use_filters)) log.debug('NOT SENT: %s %s', type(data), data) |