summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r--slixmpp/xmlstream/xmlstream.py49
1 files changed, 34 insertions, 15 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 7b362203..18464ccd 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -35,6 +35,7 @@ import ssl
import uuid
import warnings
import weakref
+import collections
from contextlib import contextmanager
import xml.etree.ElementTree as ET
@@ -82,7 +83,7 @@ class InvalidCABundle(Exception):
Exception raised when the CA Bundle file hasn't been found.
"""
- def __init__(self, path: Optional[Path]):
+ def __init__(self, path: Optional[Union[Path, Iterable[Path]]]):
self.path = path
@@ -492,16 +493,11 @@ class XMLStream(asyncio.BaseProtocol):
except Socket.gaierror as e:
self.event('connection_failed',
'No DNS record available for %s' % self.default_domain)
+ self.reschedule_connection_attempt()
except OSError as e:
log.debug('Connection failed: %s', e)
self.event("connection_failed", e)
- if self._current_connection_attempt is None:
- return
- self._connect_loop_wait = self._connect_loop_wait * 2 + 1
- self._current_connection_attempt = asyncio.ensure_future(
- self._connect_routine(),
- loop=self.loop,
- )
+ self.reschedule_connection_attempt()
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
@@ -526,7 +522,7 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.loop.run_until_complete(self.disconnected)
else:
- tasks: List[Future] = [asyncio.sleep(timeout)]
+ tasks: List[Awaitable] = [asyncio.sleep(timeout)]
if not forever:
tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks))
@@ -637,6 +633,20 @@ class XMLStream(asyncio.BaseProtocol):
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception)
+ def reschedule_connection_attempt(self) -> None:
+ """
+ Increase the exponential back-off and initate another background
+ _connect_routine call to connect to the server.
+ """
+ # abort if there is no ongoing connection attempt
+ if self._current_connection_attempt is None:
+ return
+ self._connect_loop_wait = min(300, self._connect_loop_wait * 2 + 1)
+ self._current_connection_attempt = asyncio.ensure_future(
+ self._connect_routine(),
+ loop=self.loop,
+ )
+
def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
@@ -793,11 +803,14 @@ class XMLStream(asyncio.BaseProtocol):
if bundle.is_file():
ca_cert = bundle
break
- if ca_cert is None:
- raise InvalidCABundle(ca_cert)
+ if ca_cert is None and \
+ isinstance(self.ca_certs, (Path, collections.abc.Iterable)):
+ raise InvalidCABundle(self.ca_certs)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=ca_cert)
+ else:
+ self.ssl_context.set_default_verify_paths()
return self.ssl_context
@@ -1318,10 +1331,16 @@ class XMLStream(asyncio.BaseProtocol):
# Avoid circular imports
from slixmpp.stanza.rootstanza import RootStanza
from slixmpp.stanza import Iq, Handshake
- passthrough = (
- (isinstance(data, Iq) and data.get_plugin('bind', check=True))
- or isinstance(data, Handshake)
- )
+
+ passthrough = False
+ if isinstance(data, Iq):
+ if data.get_plugin('bind', check=True):
+ passthrough = True
+ elif data.get_plugin('session', check=True):
+ passthrough = True
+ elif isinstance(data, Handshake):
+ passthrough = True
+
if isinstance(data, (RootStanza, str)) and not passthrough:
self.__queued_stanzas.append((data, use_filters))
log.debug('NOT SENT: %s %s', type(data), data)