summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/xmlstream/xmlstream.py470
1 files changed, 302 insertions, 168 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 45d814fd..dbebf2c8 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -9,17 +9,24 @@
# :license: MIT, see LICENSE for more details
from typing import (
Any,
+ Dict,
+ Awaitable,
+ Generator,
Coroutine,
Callable,
- Iterable,
Iterator,
List,
Optional,
Set,
Union,
Tuple,
+ TypeVar,
+ NoReturn,
+ Type,
+ cast,
)
+import asyncio
import functools
import logging
import socket as Socket
@@ -27,30 +34,66 @@ import ssl
import weakref
import uuid
-import asyncio
-from asyncio import iscoroutinefunction, wait, Future
from contextlib import contextmanager
import xml.etree.ElementTree as ET
+from asyncio import (
+ AbstractEventLoop,
+ BaseTransport,
+ Future,
+ Task,
+ TimerHandle,
+ Transport,
+ iscoroutinefunction,
+ wait,
+)
-from slixmpp.xmlstream import tostring
+from slixmpp.types import FilterString
+from slixmpp.xmlstream.tostring import tostring
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
from slixmpp.xmlstream.resolver import resolve, default_resolver
+from slixmpp.xmlstream.handler.base import BaseHandler
+
+T = TypeVar('T')
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__)
+
+
class ContinueQueue(Exception):
"""
Exception raised in the send queue to "continue" from within an inner loop
"""
+
class NotConnectedError(Exception):
"""
Raised when we try to send something over the wire but we are not
connected.
"""
+
+_T = TypeVar('_T', str, ElementBase, StanzaBase)
+
+
+SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
+AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
+
+
+Filter = Union[
+ SyncFilter,
+ AsyncFilter,
+]
+
+_FiltersDict = Dict[str, List[Filter]]
+
+Handler = Callable[[Any], Union[
+ Any,
+ Coroutine[Any, Any, Any]
+]]
+
+
class XMLStream(asyncio.BaseProtocol):
"""
An XML stream connection manager and event dispatcher.
@@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
- def __init__(self, host='', port=0):
- # The asyncio.Transport object provided by the connection_made()
- # callback when we are connected
- self.transport = None
+ transport: Optional[Transport]
- # The socket that is used internally by the transport object
- self.socket = None
+ # The socket that is used internally by the transport object
+ socket: Optional[ssl.SSLSocket]
+
+ # The backoff of the connect routine (increases exponentially
+ # after each failure)
+ _connect_loop_wait: float
+
+ parser: Optional[ET.XMLPullParser]
+ xml_depth: int
+ xml_root: Optional[ET.Element]
+
+ force_starttls: Optional[bool]
+ disable_starttls: Optional[bool]
+
+ waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]]
+
+ # A dict of {name: handle}
+ scheduled_events: Dict[str, TimerHandle]
+
+ ssl_context: ssl.SSLContext
+
+ # The event to trigger when the create_connection() succeeds. It can
+ # be "connected" or "tls_success" depending on the step we are at.
+ event_when_connected: str
+
+ #: The list of accepted ciphers, in OpenSSL Format.
+ #: It might be useful to override it for improved security
+ #: over the python defaults.
+ ciphers: Optional[str]
+
+ #: Path to a file containing certificates for verifying the
+ #: server SSL certificate. A non-``None`` value will trigger
+ #: certificate checking.
+ #:
+ #: .. note::
+ #:
+ #: On Mac OS X, certificates in the system keyring will
+ #: be consulted, even if they are not in the provided file.
+ ca_certs: Optional[str]
+
+ #: Path to a file containing a client certificate to use for
+ #: authenticating via SASL EXTERNAL. If set, there must also
+ #: be a corresponding `:attr:keyfile` value.
+ certfile: Optional[str]
+
+ #: Path to a file containing the private key for the selected
+ #: client certificate to use for authenticating via SASL EXTERNAL.
+ keyfile: Optional[str]
+
+ # The asyncio event loop
+ _loop: Optional[AbstractEventLoop]
+
+ #: The default port to return when querying DNS records.
+ default_port: int
+
+ #: The domain to try when querying DNS records.
+ default_domain: str
+
+ #: The expected name of the server, for validation.
+ _expected_server_name: str
+ _service_name: str
+
+ #: The desired, or actual, address of the connected server.
+ address: Tuple[str, int]
+
+ #: Enable connecting to the server directly over SSL, in
+ #: particular when the service provides two ports: one for
+ #: non-SSL traffic and another for SSL traffic.
+ use_ssl: bool
+
+ #: If set to ``True``, attempt to use IPv6.
+ use_ipv6: bool
- # The backoff of the connect routine (increases exponentially
- # after each failure)
+ #: If set to ``True``, allow using the ``dnspython`` DNS library
+ #: if available. If set to ``False``, the builtin DNS resolver
+ #: will be used, even if ``dnspython`` is installed.
+ use_aiodns: bool
+
+ #: Use CDATA for escaping instead of XML entities. Defaults
+ #: to ``False``.
+ use_cdata: bool
+
+ #: The default namespace of the stream content, not of the
+ #: stream wrapper it
+ default_ns: str
+
+ default_lang: Optional[str]
+ peer_default_lang: Optional[str]
+
+ #: The namespace of the enveloping stream element.
+ stream_ns: str
+
+ #: The default opening tag for the stream element.
+ stream_header: str
+
+ #: The default closing tag for the stream element.
+ stream_footer: str
+
+ #: If ``True``, periodically send a whitespace character over the
+ #: wire to keep the connection alive. Mainly useful for connections
+ #: traversing NAT.
+ whitespace_keepalive: bool
+
+ #: The default interval between keepalive signals when
+ #: :attr:`whitespace_keepalive` is enabled.
+ whitespace_keepalive_interval: int
+
+ #: Flag for controlling if the session can be considered ended
+ #: if the connection is terminated.
+ end_session_on_disconnect: bool
+
+ #: A mapping of XML namespaces to well-known prefixes.
+ namespace_map: dict
+
+ __root_stanza: List[Type[StanzaBase]]
+ __handlers: List[BaseHandler]
+ __event_handlers: Dict[str, List[Tuple[Handler, bool]]]
+ __filters: _FiltersDict
+
+ # Current connection attempt (Future)
+ _current_connection_attempt: Optional[Future[None]]
+
+ #: A list of DNS results that have not yet been tried.
+ _dns_answers: Optional[Iterator[Tuple[str, str, int]]]
+
+ #: The service name to check with DNS SRV records. For
+ #: example, setting this to ``'xmpp-client'`` would query the
+ #: ``_xmpp-client._tcp`` service.
+ dns_service: Optional[str]
+
+ #: The reason why we are disconnecting from the server
+ disconnect_reason: Optional[str]
+
+ #: An asyncio Future being done when the stream is disconnected.
+ disconnected: Future[bool]
+
+ # If the session has been started or not
+ _session_started: bool
+ # If we want to bypass the send() check (e.g. unit tests)
+ _always_send_everything: bool
+
+ _run_out_filters: Optional[Future]
+ __slow_tasks: List[Task]
+ __queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
+
+ def __init__(self, host: str = '', port: int = 0):
+ self.transport = None
+ self.socket = None
self._connect_loop_wait = 0
self.parser = None
@@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
- # The event to trigger when the create_connection() succeeds. It can
- # be "connected" or "tls_success" depending on the step we are at.
self.event_when_connected = "connected"
- #: The list of accepted ciphers, in OpenSSL Format.
- #: It might be useful to override it for improved security
- #: over the python defaults.
self.ciphers = None
- #: Path to a file containing certificates for verifying the
- #: server SSL certificate. A non-``None`` value will trigger
- #: certificate checking.
- #:
- #: .. note::
- #:
- #: On Mac OS X, certificates in the system keyring will
- #: be consulted, even if they are not in the provided file.
self.ca_certs = None
- #: Path to a file containing a client certificate to use for
- #: authenticating via SASL EXTERNAL. If set, there must also
- #: be a corresponding `:attr:keyfile` value.
- self.certfile = None
-
- #: Path to a file containing the private key for the selected
- #: client certificate to use for authenticating via SASL EXTERNAL.
self.keyfile = None
- self._der_cert = None
-
- # The asyncio event loop
self._loop = None
- #: The default port to return when querying DNS records.
self.default_port = int(port)
-
- #: The domain to try when querying DNS records.
self.default_domain = ''
- #: The expected name of the server, for validation.
self._expected_server_name = ''
self._service_name = ''
- #: The desired, or actual, address of the connected server.
self.address = (host, int(port))
- #: Enable connecting to the server directly over SSL, in
- #: particular when the service provides two ports: one for
- #: non-SSL traffic and another for SSL traffic.
self.use_ssl = False
-
- #: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
- #: If set to ``True``, allow using the ``dnspython`` DNS library
- #: if available. If set to ``False``, the builtin DNS resolver
- #: will be used, even if ``dnspython`` is installed.
self.use_aiodns = True
-
- #: Use CDATA for escaping instead of XML entities. Defaults
- #: to ``False``.
self.use_cdata = False
- #: The default namespace of the stream content, not of the
- #: stream wrapper itself.
self.default_ns = ''
self.default_lang = None
self.peer_default_lang = None
- #: The namespace of the enveloping stream element.
self.stream_ns = ''
-
- #: The default opening tag for the stream element.
self.stream_header = "<stream>"
-
- #: The default closing tag for the stream element.
self.stream_footer = "</stream>"
- #: If ``True``, periodically send a whitespace character over the
- #: wire to keep the connection alive. Mainly useful for connections
- #: traversing NAT.
self.whitespace_keepalive = True
-
- #: The default interval between keepalive signals when
- #: :attr:`whitespace_keepalive` is enabled.
self.whitespace_keepalive_interval = 300
- #: Flag for controlling if the session can be considered ended
- #: if the connection is terminated.
self.end_session_on_disconnect = True
-
- #: A mapping of XML namespaces to well-known prefixes.
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
self.__root_stanza = []
self.__handlers = []
self.__event_handlers = {}
- self.__filters = {'in': [], 'out': [], 'out_sync': []}
+ self.__filters = {
+ 'in': [], 'out': [], 'out_sync': []
+ }
- # Current connection attempt (Future)
self._current_connection_attempt = None
- #: A list of DNS results that have not yet been tried.
- self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
-
- #: The service name to check with DNS SRV records. For
- #: example, setting this to ``'xmpp-client'`` would query the
- #: ``_xmpp-client._tcp`` service.
+ self._dns_answers = None
self.dns_service = None
- #: The reason why we are disconnecting from the server
self.disconnect_reason = None
-
- #: An asyncio Future being done when the stream is disconnected.
- self.disconnected: Future = Future()
-
- # If the session has been started or not
+ self.disconnected = Future()
self._session_started = False
- # If we want to bypass the send() check (e.g. unit tests)
self._always_send_everything = False
self.add_event_handler('disconnected', self._remove_schedules)
@@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._set_session_start)
self.add_event_handler('session_resumed', self._set_session_start)
- self._run_out_filters: Optional[Future] = None
- self.__slow_tasks: List[Future] = []
- self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
+ self._run_out_filters = None
+ self.__slow_tasks = []
+ self.__queued_stanzas = []
@property
- def loop(self):
+ def loop(self) -> AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
- def loop(self, value):
+ def loop(self, value: AbstractEventLoop) -> None:
self._loop = value
- def new_id(self):
+ def new_id(self) -> str:
"""Generate and return a new stream ID in hexadecimal form.
Many stanzas, handlers, or matchers may require unique
@@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
- def _set_session_start(self, event):
+ def _set_session_start(self, event: Any) -> None:
"""
On session start, queue all pending stanzas to be sent.
"""
@@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
self.waiting_queue.put_nowait(stanza)
self.__queued_stanzas = []
- def _set_disconnected(self, event):
+ def _set_disconnected(self, event: Any) -> None:
self._session_started = False
- def _set_disconnected_future(self):
+ def _set_disconnected_future(self) -> None:
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
- def connect(self, host='', port=0, use_ssl=False,
- force_starttls=True, disable_starttls=False):
+ def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
+ force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
"""Create a new socket and connect to the server.
:param host: The name of the desired server for the connection.
@@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
- async def _connect_routine(self):
+ async def _connect_routine(self) -> None:
self.event_when_connected = "connected"
if self._connect_loop_wait > 0:
@@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
# and try (host, port) as a last resort
self._dns_answers = None
+ ssl_context: Optional[ssl.SSLContext]
if self.use_ssl:
ssl_context = self.get_ssl_context()
else:
@@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
- def process(self, *, forever=True, timeout=None):
+ def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
socket(s), calling various registered callbacks, calling expired
timers, handling signal events, etc). If timeout is None, this
@@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.loop.run_until_complete(self.disconnected)
else:
- tasks = [asyncio.sleep(timeout, loop=self.loop)]
+ tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)]
if not forever:
tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
- def init_parser(self):
+ def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new
connexion
"""
@@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_root = None
self.parser = ET.XMLPullParser(("start", "end"))
- def connection_made(self, transport):
+ def connection_made(self, transport: BaseTransport) -> None:
"""Called when the TCP connection has been established with the server
"""
self.event(self.event_when_connected)
- self.transport = transport
+ self.transport = cast(Transport, transport)
+ if self.transport is None:
+ raise ValueError("Transport cannot be none")
self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
@@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(self.stream_header)
self._dns_answers = None
- def data_received(self, data):
+ def data_received(self, data: bytes) -> None:
"""Called when incoming data is received on the socket.
We feed that data to the parser and the see if this produced any XML
@@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error)
self.disconnect()
- def is_connecting(self):
+ def is_connecting(self) -> bool:
return self._current_connection_attempt is not None
- def is_connected(self):
+ def is_connected(self) -> bool:
return self.transport is not None
- def eof_received(self):
+ def eof_received(self) -> None:
"""When the TCP connection is properly closed by the remote end
"""
self.event("eof_received")
- def connection_lost(self, exception):
+ def connection_lost(self, exception: Optional[BaseException]) -> None:
"""On any kind of disconnection, initiated by us or not. This signals the
closure of the TCP connection
"""
@@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
self._reset_sendq()
self.event('session_end')
self._set_disconnected_future()
- self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
+ self.event("disconnected", self.disconnect_reason or exception)
- def cancel_connection_attempt(self):
+ def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect
@@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
- def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
+ def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]:
"""Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server
@@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
# schedule call below. It would fortunately be converted to `1` later
# down the call chain. Praise the implicit casts lord.
- if wait == True:
+ if wait is True:
wait = 2.0
if self.transport:
@@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self._set_disconnected_future()
self.event("disconnected", reason)
- future = Future()
+ future: Future[None] = Future()
future.set_result(None)
return future
- async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
+ async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
"""Wait until the send queue is empty before disconnecting"""
try:
await asyncio.wait_for(
@@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = reason
await self._end_stream_wait(wait)
- async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
+ async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
"""
Run abort() if we do not received the disconnected event
after a waiting time.
@@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
# that means the disconnect has already been handled
pass
- def abort(self):
+ def abort(self) -> None:
"""
Forcibly close the connection
"""
@@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
self.transport.abort()
self.event("killed")
- def reconnect(self, wait=2.0, reason="Reconnecting"):
+ def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
- async def handler(event):
+ async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
- def configure_socket(self):
+ def configure_socket(self) -> None:
"""Set timeout and other options for self.socket.
Meant to be overridden.
"""
pass
- def configure_dns(self, resolver, domain=None, port=None):
+ def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
"""
Configure and set options for a :class:`~dns.resolver.Resolver`
instance, and other DNS related tasks. For example, you
@@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
- def get_ssl_context(self):
+ def get_ssl_context(self) -> ssl.SSLContext:
"""
Get SSL context.
"""
@@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
- async def start_tls(self):
+ async def start_tls(self) -> bool:
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
+ if self.transport is None:
+ raise ValueError("Transport should not be None")
self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context()
try:
@@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
self.connection_made(transp)
return True
- def _start_keepalive(self, event):
+ def _start_keepalive(self, event: Any) -> None:
"""Begin sending whitespace periodically to keep the connection alive.
May be disabled by setting::
@@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
args=(' ',),
repeat=True)
- def _remove_schedules(self, event):
+ def _remove_schedules(self, event: Any) -> None:
"""Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive')
- def start_stream_handler(self, xml):
+ def start_stream_handler(self, xml: ET.Element) -> None:
"""Perform any initialization actions, such as handshakes,
once the stream header has been sent.
@@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
- def register_stanza(self, stanza_class):
+ def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Add a stanza object class as a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.append(stanza_class)
- def remove_stanza(self, stanza_class):
+ def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Remove a stanza from being a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.remove(stanza_class)
- def add_filter(self, mode, handler, order=None):
+ def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
"""Add a filter for incoming or outgoing stanzas.
These filters are applied before incoming stanzas are
@@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.__filters[mode].append(handler)
- def del_filter(self, mode, handler):
+ def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
"""Remove an incoming or outgoing filter."""
self.__filters[mode].remove(handler)
- def register_handler(self, handler, before=None, after=None):
+ def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
"""Add a stream event handler that will be executed when a matching
stanza is received.
@@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__handlers.append(handler)
handler.stream = weakref.ref(self)
- def remove_handler(self, name):
+ def remove_handler(self, name: str) -> bool:
"""Remove any stream event handlers with the given name.
:param name: The name of the handler.
@@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
try:
return next(self._dns_answers)
except StopIteration:
- return
+ return None
- def add_event_handler(self, name, pointer, disposable=False):
+ def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
"""Add a custom event handler that will be executed whenever
its event is manually triggered.
@@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__event_handlers[name] = []
self.__event_handlers[name].append((pointer, disposable))
- def del_event_handler(self, name, pointer):
+ def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
"""Remove a function as a handler for an event.
:param name: The name of the event.
@@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
# Need to keep handlers that do not use
# the given function pointer
- def filter_pointers(handler):
+ def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
return handler[0] != pointer
self.__event_handlers[name] = list(filter(
filter_pointers,
self.__event_handlers[name]))
- def event_handled(self, name):
+ def event_handled(self, name: str) -> int:
"""Returns the number of registered handlers for an event.
:param name: The name of the event to check.
"""
return len(self.__event_handlers.get(name, []))
- async def event_async(self, name: str, data: Any = {}):
+ async def event_async(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event, but await coroutines immediately.
This event generator should only be called in situations when
@@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
except Exception as e:
self.exception(e)
- def event(self, name: str, data: Any = {}):
+ def event(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event.
Coroutine handlers are wrapped into a future and sent into the
event loop for their execution, and not awaited.
@@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of
# running it directly
if iscoroutinefunction(handler_callback):
- async def handler_callback_routine(cb):
+ async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
try:
await cb(data)
except Exception as e:
@@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
except ValueError:
pass
- def schedule(self, name, seconds, callback, args=tuple(),
- kwargs={}, repeat=False):
+ def schedule(self, name: str, seconds: int, callback: Callable[..., None],
+ args: Tuple[Any, ...] = tuple(),
+ kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
"""Schedule a callback function to execute after a given delay.
:param name: A unique name for the scheduled callback.
@@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
# canceling scheduled_events[name]
self.scheduled_events[name] = handle
- def cancel_schedule(self, name):
+ def cancel_schedule(self, name: str) -> None:
try:
handle = self.scheduled_events.pop(name)
handle.cancel()
except KeyError:
log.debug("Tried to cancel unscheduled event: %s" % (name,))
- def _safe_cb_run(self, name, cb):
+ def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
log.debug('Scheduled event: %s', name)
try:
cb()
except Exception as e:
self.exception(e)
- def _execute_and_reschedule(self, name, cb, seconds):
+ def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
"""Simple method that calls the given callback, and then schedule itself to
be called after the given number of seconds.
"""
@@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
name, cb, seconds)
self.scheduled_events[name] = handle
- def _execute_and_unschedule(self, name, cb):
+ def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
"""
Execute the callback and remove the handler for it.
"""
@@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
if name in self.scheduled_events:
del self.scheduled_events[name]
- def incoming_filter(self, xml):
+ def incoming_filter(self, xml: ET.Element) -> ET.Element:
"""Filter incoming XML objects before they are processed.
Possible uses include remapping namespaces, or correcting elements
@@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
- def _reset_sendq(self):
+ def _reset_sendq(self) -> None:
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
@@ -1042,8 +1165,8 @@ class XMLStream(asyncio.BaseProtocol):
async def _continue_slow_send(
self,
- task: asyncio.Task,
- already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
+ task: asyncio.Task[Optional[StanzaBase]],
+ already_used: Set[Filter]
) -> None:
"""
Used when an item in the send queue has taken too long to process.
@@ -1062,12 +1185,14 @@ class XMLStream(asyncio.BaseProtocol):
if iscoroutinefunction(filter):
data = await filter(data)
else:
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
- if isinstance(data, ElementBase):
+ if isinstance(data, StanzaBase):
for filter in self.__filters['out_sync']:
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
@@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
- async def run_filters(self):
+ async def run_filters(self) -> NoReturn:
"""
Background loop that processes stanzas to send.
"""
while True:
+ data: Optional[Union[StanzaBase, str]]
(data, use_filters) = await self.waiting_queue.get()
try:
- if isinstance(data, ElementBase):
+ if isinstance(data, StanzaBase):
if use_filters:
already_run_filters = set()
for filter in self.__filters['out']:
already_run_filters.add(filter)
if iscoroutinefunction(filter):
+ filter = cast(AsyncFilter, filter)
task = asyncio.create_task(filter(data))
completed, pending = await wait(
{task},
@@ -1108,19 +1235,24 @@ class XMLStream(asyncio.BaseProtocol):
"Slow coroutine, rescheduling filters"
)
data = task.result()
- else:
+ elif isinstance(data, StanzaBase):
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
- if isinstance(data, ElementBase):
+ if isinstance(data, StanzaBase):
if use_filters:
for filter in self.__filters['out_sync']:
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
- str_data = tostring(data.xml, xmlns=self.default_ns,
- stream=self, top_level=True)
+ if isinstance(data, StanzaBase):
+ str_data = tostring(data.xml, xmlns=self.default_ns,
+ stream=self, top_level=True)
+ else:
+ str_data = data
self.send_raw(str_data)
else:
self.send_raw(data)
@@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
log.error('Exception raised in send queue:', exc_info=True)
self.waiting_queue.task_done()
- def send(self, data, use_filters=True):
+ def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
- :param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to send on the stream.
:param bool use_filters: Indicates if outgoing filters should be
applied to the given stanza data. Disabling
@@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
return
self.waiting_queue.put_nowait((data, use_filters))
- def send_xml(self, data):
+ def send_xml(self, data: ET.Element) -> None:
"""Send an XML object on the stream
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
to send on the stream.
"""
- return self.send(tostring(data))
+ self.send(tostring(data))
- def send_raw(self, data):
+ def send_raw(self, data: Union[str, bytes]) -> None:
"""Send raw data across the stream.
:param string data: Any bytes or utf-8 string value.
@@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
data = data.encode('utf-8')
self.transport.write(data)
- def _build_stanza(self, xml, default_ns=None):
+ def _build_stanza(self, xml: ET.Element,
+ default_ns: Optional[str] = None) -> StanzaBase:
"""Create a stanza object from a given XML object.
If a specialized stanza type is not found for the XML, then
@@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
stanza['lang'] = self.peer_default_lang
return stanza
- def _spawn_event(self, xml):
+ def _spawn_event(self, xml: ET.Element) -> None:
"""
Analyze incoming XML stanzas and convert them into stanza
objects if applicable and queue stream events to be processed
@@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
# Convert the raw XML object into a stanza object. If no registered
# stanza type applies, a generic StanzaBase stanza will be used.
- stanza = self._build_stanza(xml)
+ stanza: Optional[StanzaBase] = self._build_stanza(xml)
for filter in self.__filters['in']:
if stanza is not None:
+ filter = cast(SyncFilter, filter)
stanza = filter(stanza)
if stanza is None:
return
@@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
if not handled:
stanza.unhandled()
- def exception(self, exception):
+ def exception(self, exception: Exception) -> None:
"""Process an unknown exception.
Meant to be overridden.
@@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
- async def wait_until(self, event: str, timeout=30) -> Any:
+ async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
"""Utility method to wake on the next firing of an event.
(Registers a disposable handler on it)
@@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
"""
- fut = asyncio.Future()
+ fut: Future[Any] = asyncio.Future()
- def result_handler(event_data):
+ def result_handler(event_data: Any) -> None:
if not fut.done():
fut.set_result(event_data)
else:
@@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
return await asyncio.wait_for(fut, timeout)
@contextmanager
- def event_handler(self, event: str, handler: Callable):
+ def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
"""
Context manager that adds then removes an event handler.
"""
self.add_event_handler(event, handler)
try:
yield
- except Exception as exc:
+ except Exception:
raise
finally:
self.del_event_handler(event, handler)
- def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future:
+ def wrap(self, coroutine: Coroutine[None, None, T]) -> Future[T]:
"""Make a Future out of a coroutine with the current loop.
:param coroutine: The coroutine to wrap.