summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r--slixmpp/xmlstream/xmlstream.py101
1 files changed, 80 insertions, 21 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 30f99071..19c4ddcc 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -15,6 +15,7 @@ from typing import (
Coroutine,
Callable,
Iterator,
+ Iterable,
List,
Optional,
Set,
@@ -31,8 +32,10 @@ import functools
import logging
import socket as Socket
import ssl
-import weakref
import uuid
+import warnings
+import weakref
+import collections
from contextlib import contextmanager
import xml.etree.ElementTree as ET
@@ -46,6 +49,7 @@ from asyncio import (
iscoroutinefunction,
wait,
)
+from pathlib import Path
from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
@@ -74,6 +78,15 @@ class NotConnectedError(Exception):
"""
+class InvalidCABundle(Exception):
+ """
+ Exception raised when the CA Bundle file hasn't been found.
+ """
+
+ def __init__(self, path: Optional[Union[Path, Iterable[Path]]]):
+ self.path = path
+
+
_T = TypeVar('_T', str, ElementBase, StanzaBase)
@@ -161,7 +174,7 @@ class XMLStream(asyncio.BaseProtocol):
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
- ca_certs: Optional[str]
+ ca_certs: Optional[Union[Path, Iterable[Path]]]
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
@@ -449,7 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
if self._connect_loop_wait > 0:
self.event('reconnect_delay', self._connect_loop_wait)
- await asyncio.sleep(self._connect_loop_wait, loop=self.loop)
+ await asyncio.sleep(self._connect_loop_wait)
record = await self._pick_dns_answer(self.default_domain)
if record is not None:
@@ -480,16 +493,11 @@ class XMLStream(asyncio.BaseProtocol):
except Socket.gaierror as e:
self.event('connection_failed',
'No DNS record available for %s' % self.default_domain)
+ self.reschedule_connection_attempt()
except OSError as e:
log.debug('Connection failed: %s', e)
self.event("connection_failed", e)
- if self._current_connection_attempt is None:
- return
- self._connect_loop_wait = self._connect_loop_wait * 2 + 1
- self._current_connection_attempt = asyncio.ensure_future(
- self._connect_routine(),
- loop=self.loop,
- )
+ self.reschedule_connection_attempt()
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
@@ -497,17 +505,27 @@ class XMLStream(asyncio.BaseProtocol):
timers, handling signal events, etc). If timeout is None, this
function will run forever. If timeout is a number, this function
will return after the given time in seconds.
+
+ Will be removed in slixmpp 1.9.0
+
+ :deprecated: 1.8.0
"""
+ warnings.warn(
+ 'This function will be removed in slixmpp 1.9 and above.'
+ ' Use the asyncio normal functions instead.',
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
if timeout is None:
if forever:
self.loop.run_forever()
else:
self.loop.run_until_complete(self.disconnected)
else:
- tasks: List[Future] = [asyncio.sleep(timeout, loop=self.loop)]
+ tasks: List[Awaitable] = [asyncio.sleep(timeout)]
if not forever:
tasks.append(self.disconnected)
- self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
+ self.loop.run_until_complete(asyncio.wait(tasks))
def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new
@@ -556,7 +574,7 @@ class XMLStream(asyncio.BaseProtocol):
stream=self,
top_level=True,
open_only=True))
- self.start_stream_handler(self.xml_root)
+ self.start_stream_handler(self.xml_root) # type:ignore
self.xml_depth += 1
if event == 'end':
self.xml_depth -= 1
@@ -615,6 +633,20 @@ class XMLStream(asyncio.BaseProtocol):
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception)
+ def reschedule_connection_attempt(self) -> None:
+ """
+ Increase the exponential back-off and initate another background
+ _connect_routine call to connect to the server.
+ """
+ # abort if there is no ongoing connection attempt
+ if self._current_connection_attempt is None:
+ return
+ self._connect_loop_wait = min(300, self._connect_loop_wait * 2 + 1)
+ self._current_connection_attempt = asyncio.ensure_future(
+ self._connect_routine(),
+ loop=self.loop,
+ )
+
def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
@@ -715,7 +747,7 @@ class XMLStream(asyncio.BaseProtocol):
log.debug("reconnecting...")
async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first
- await asyncio.sleep(0, loop=self.loop)
+ await asyncio.sleep(0)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
@@ -759,8 +791,26 @@ class XMLStream(asyncio.BaseProtocol):
log.debug('Loaded cert file %s and key file %s',
self.certfile, self.keyfile)
if self.ca_certs is not None:
+ ca_cert: Optional[Path] = None
+ # XXX: Compat before d733c54518.
+ if isinstance(self.ca_certs, str):
+ self.ca_certs = Path(self.ca_certs)
+ if isinstance(self.ca_certs, Path):
+ if self.ca_certs.is_file():
+ ca_cert = self.ca_certs
+ else:
+ for bundle in self.ca_certs:
+ if bundle.is_file():
+ ca_cert = bundle
+ break
+ if ca_cert is None and \
+ isinstance(self.ca_certs, (Path, collections.abc.Iterable)):
+ raise InvalidCABundle(self.ca_certs)
+
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
- self.ssl_context.load_verify_locations(cafile=self.ca_certs)
+ self.ssl_context.load_verify_locations(cafile=ca_cert)
+ else:
+ self.ssl_context.set_default_verify_paths()
return self.ssl_context
@@ -1202,7 +1252,7 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
- async def run_filters(self) -> NoReturn:
+ async def run_filters(self) -> None:
"""
Background loop that processes stanzas to send.
"""
@@ -1217,7 +1267,7 @@ class XMLStream(asyncio.BaseProtocol):
already_run_filters.add(filter)
if iscoroutinefunction(filter):
filter = cast(AsyncFilter, filter)
- task = asyncio.create_task(filter(data))
+ task = asyncio.create_task(filter(data)) # type:ignore
completed, pending = await wait(
{task},
timeout=1,
@@ -1258,6 +1308,9 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(data)
except ContinueQueue as exc:
log.debug('Stanza in send queue not sent: %s', exc)
+ except asyncio.CancelledError:
+ log.debug('Send coroutine received cancel(), stopping')
+ return
except Exception:
log.error('Exception raised in send queue:', exc_info=True)
self.waiting_queue.task_done()
@@ -1278,10 +1331,16 @@ class XMLStream(asyncio.BaseProtocol):
# Avoid circular imports
from slixmpp.stanza.rootstanza import RootStanza
from slixmpp.stanza import Iq, Handshake
- passthrough = (
- (isinstance(data, Iq) and data.get_plugin('bind', check=True))
- or isinstance(data, Handshake)
- )
+
+ passthrough = False
+ if isinstance(data, Iq):
+ if data.get_plugin('bind', check=True):
+ passthrough = True
+ elif data.get_plugin('session', check=True):
+ passthrough = True
+ elif isinstance(data, Handshake):
+ passthrough = True
+
if isinstance(data, (RootStanza, str)) and not passthrough:
self.__queued_stanzas.append((data, use_filters))
log.debug('NOT SENT: %s %s', type(data), data)