From db48c8f4da2448343284f871ac9464c45ce0b66f Mon Sep 17 00:00:00 2001 From: mathieui Date: Sat, 3 Jul 2021 11:07:01 +0200 Subject: xmlstream: add more types --- slixmpp/xmlstream/xmlstream.py | 470 ++++++++++++++++++++++++++--------------- 1 file changed, 302 insertions(+), 168 deletions(-) diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 45d814fd..dbebf2c8 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -9,17 +9,24 @@ # :license: MIT, see LICENSE for more details from typing import ( Any, + Dict, + Awaitable, + Generator, Coroutine, Callable, - Iterable, Iterator, List, Optional, Set, Union, Tuple, + TypeVar, + NoReturn, + Type, + cast, ) +import asyncio import functools import logging import socket as Socket @@ -27,30 +34,66 @@ import ssl import weakref import uuid -import asyncio -from asyncio import iscoroutinefunction, wait, Future from contextlib import contextmanager import xml.etree.ElementTree as ET +from asyncio import ( + AbstractEventLoop, + BaseTransport, + Future, + Task, + TimerHandle, + Transport, + iscoroutinefunction, + wait, +) -from slixmpp.xmlstream import tostring +from slixmpp.types import FilterString +from slixmpp.xmlstream.tostring import tostring from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase from slixmpp.xmlstream.resolver import resolve, default_resolver +from slixmpp.xmlstream.handler.base import BaseHandler + +T = TypeVar('T') #: The time in seconds to wait before timing out waiting for response stanzas. RESPONSE_TIMEOUT = 30 log = logging.getLogger(__name__) + + class ContinueQueue(Exception): """ Exception raised in the send queue to "continue" from within an inner loop """ + class NotConnectedError(Exception): """ Raised when we try to send something over the wire but we are not connected. """ + +_T = TypeVar('_T', str, ElementBase, StanzaBase) + + +SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]] +AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]] + + +Filter = Union[ + SyncFilter, + AsyncFilter, +] + +_FiltersDict = Dict[str, List[Filter]] + +Handler = Callable[[Any], Union[ + Any, + Coroutine[Any, Any, Any] +]] + + class XMLStream(asyncio.BaseProtocol): """ An XML stream connection manager and event dispatcher. @@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol): :param int port: The port to use for the connection. Defaults to 0. """ - def __init__(self, host='', port=0): - # The asyncio.Transport object provided by the connection_made() - # callback when we are connected - self.transport = None + transport: Optional[Transport] - # The socket that is used internally by the transport object - self.socket = None + # The socket that is used internally by the transport object + socket: Optional[ssl.SSLSocket] + + # The backoff of the connect routine (increases exponentially + # after each failure) + _connect_loop_wait: float + + parser: Optional[ET.XMLPullParser] + xml_depth: int + xml_root: Optional[ET.Element] + + force_starttls: Optional[bool] + disable_starttls: Optional[bool] + + waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]] + + # A dict of {name: handle} + scheduled_events: Dict[str, TimerHandle] + + ssl_context: ssl.SSLContext + + # The event to trigger when the create_connection() succeeds. It can + # be "connected" or "tls_success" depending on the step we are at. + event_when_connected: str + + #: The list of accepted ciphers, in OpenSSL Format. + #: It might be useful to override it for improved security + #: over the python defaults. + ciphers: Optional[str] + + #: Path to a file containing certificates for verifying the + #: server SSL certificate. A non-``None`` value will trigger + #: certificate checking. + #: + #: .. note:: + #: + #: On Mac OS X, certificates in the system keyring will + #: be consulted, even if they are not in the provided file. + ca_certs: Optional[str] + + #: Path to a file containing a client certificate to use for + #: authenticating via SASL EXTERNAL. If set, there must also + #: be a corresponding `:attr:keyfile` value. + certfile: Optional[str] + + #: Path to a file containing the private key for the selected + #: client certificate to use for authenticating via SASL EXTERNAL. + keyfile: Optional[str] + + # The asyncio event loop + _loop: Optional[AbstractEventLoop] + + #: The default port to return when querying DNS records. + default_port: int + + #: The domain to try when querying DNS records. + default_domain: str + + #: The expected name of the server, for validation. + _expected_server_name: str + _service_name: str + + #: The desired, or actual, address of the connected server. + address: Tuple[str, int] + + #: Enable connecting to the server directly over SSL, in + #: particular when the service provides two ports: one for + #: non-SSL traffic and another for SSL traffic. + use_ssl: bool + + #: If set to ``True``, attempt to use IPv6. + use_ipv6: bool - # The backoff of the connect routine (increases exponentially - # after each failure) + #: If set to ``True``, allow using the ``dnspython`` DNS library + #: if available. If set to ``False``, the builtin DNS resolver + #: will be used, even if ``dnspython`` is installed. + use_aiodns: bool + + #: Use CDATA for escaping instead of XML entities. Defaults + #: to ``False``. + use_cdata: bool + + #: The default namespace of the stream content, not of the + #: stream wrapper it + default_ns: str + + default_lang: Optional[str] + peer_default_lang: Optional[str] + + #: The namespace of the enveloping stream element. + stream_ns: str + + #: The default opening tag for the stream element. + stream_header: str + + #: The default closing tag for the stream element. + stream_footer: str + + #: If ``True``, periodically send a whitespace character over the + #: wire to keep the connection alive. Mainly useful for connections + #: traversing NAT. + whitespace_keepalive: bool + + #: The default interval between keepalive signals when + #: :attr:`whitespace_keepalive` is enabled. + whitespace_keepalive_interval: int + + #: Flag for controlling if the session can be considered ended + #: if the connection is terminated. + end_session_on_disconnect: bool + + #: A mapping of XML namespaces to well-known prefixes. + namespace_map: dict + + __root_stanza: List[Type[StanzaBase]] + __handlers: List[BaseHandler] + __event_handlers: Dict[str, List[Tuple[Handler, bool]]] + __filters: _FiltersDict + + # Current connection attempt (Future) + _current_connection_attempt: Optional[Future[None]] + + #: A list of DNS results that have not yet been tried. + _dns_answers: Optional[Iterator[Tuple[str, str, int]]] + + #: The service name to check with DNS SRV records. For + #: example, setting this to ``'xmpp-client'`` would query the + #: ``_xmpp-client._tcp`` service. + dns_service: Optional[str] + + #: The reason why we are disconnecting from the server + disconnect_reason: Optional[str] + + #: An asyncio Future being done when the stream is disconnected. + disconnected: Future[bool] + + # If the session has been started or not + _session_started: bool + # If we want to bypass the send() check (e.g. unit tests) + _always_send_everything: bool + + _run_out_filters: Optional[Future] + __slow_tasks: List[Task] + __queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]] + + def __init__(self, host: str = '', port: int = 0): + self.transport = None + self.socket = None self._connect_loop_wait = 0 self.parser = None @@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol): self.ssl_context.check_hostname = False self.ssl_context.verify_mode = ssl.CERT_NONE - # The event to trigger when the create_connection() succeeds. It can - # be "connected" or "tls_success" depending on the step we are at. self.event_when_connected = "connected" - #: The list of accepted ciphers, in OpenSSL Format. - #: It might be useful to override it for improved security - #: over the python defaults. self.ciphers = None - #: Path to a file containing certificates for verifying the - #: server SSL certificate. A non-``None`` value will trigger - #: certificate checking. - #: - #: .. note:: - #: - #: On Mac OS X, certificates in the system keyring will - #: be consulted, even if they are not in the provided file. self.ca_certs = None - #: Path to a file containing a client certificate to use for - #: authenticating via SASL EXTERNAL. If set, there must also - #: be a corresponding `:attr:keyfile` value. - self.certfile = None - - #: Path to a file containing the private key for the selected - #: client certificate to use for authenticating via SASL EXTERNAL. self.keyfile = None - self._der_cert = None - - # The asyncio event loop self._loop = None - #: The default port to return when querying DNS records. self.default_port = int(port) - - #: The domain to try when querying DNS records. self.default_domain = '' - #: The expected name of the server, for validation. self._expected_server_name = '' self._service_name = '' - #: The desired, or actual, address of the connected server. self.address = (host, int(port)) - #: Enable connecting to the server directly over SSL, in - #: particular when the service provides two ports: one for - #: non-SSL traffic and another for SSL traffic. self.use_ssl = False - - #: If set to ``True``, attempt to use IPv6. self.use_ipv6 = True - #: If set to ``True``, allow using the ``dnspython`` DNS library - #: if available. If set to ``False``, the builtin DNS resolver - #: will be used, even if ``dnspython`` is installed. self.use_aiodns = True - - #: Use CDATA for escaping instead of XML entities. Defaults - #: to ``False``. self.use_cdata = False - #: The default namespace of the stream content, not of the - #: stream wrapper itself. self.default_ns = '' self.default_lang = None self.peer_default_lang = None - #: The namespace of the enveloping stream element. self.stream_ns = '' - - #: The default opening tag for the stream element. self.stream_header = "" - - #: The default closing tag for the stream element. self.stream_footer = "" - #: If ``True``, periodically send a whitespace character over the - #: wire to keep the connection alive. Mainly useful for connections - #: traversing NAT. self.whitespace_keepalive = True - - #: The default interval between keepalive signals when - #: :attr:`whitespace_keepalive` is enabled. self.whitespace_keepalive_interval = 300 - #: Flag for controlling if the session can be considered ended - #: if the connection is terminated. self.end_session_on_disconnect = True - - #: A mapping of XML namespaces to well-known prefixes. self.namespace_map = {StanzaBase.xml_ns: 'xml'} self.__root_stanza = [] self.__handlers = [] self.__event_handlers = {} - self.__filters = {'in': [], 'out': [], 'out_sync': []} + self.__filters = { + 'in': [], 'out': [], 'out_sync': [] + } - # Current connection attempt (Future) self._current_connection_attempt = None - #: A list of DNS results that have not yet been tried. - self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None - - #: The service name to check with DNS SRV records. For - #: example, setting this to ``'xmpp-client'`` would query the - #: ``_xmpp-client._tcp`` service. + self._dns_answers = None self.dns_service = None - #: The reason why we are disconnecting from the server self.disconnect_reason = None - - #: An asyncio Future being done when the stream is disconnected. - self.disconnected: Future = Future() - - # If the session has been started or not + self.disconnected = Future() self._session_started = False - # If we want to bypass the send() check (e.g. unit tests) self._always_send_everything = False self.add_event_handler('disconnected', self._remove_schedules) @@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol): self.add_event_handler('session_start', self._set_session_start) self.add_event_handler('session_resumed', self._set_session_start) - self._run_out_filters: Optional[Future] = None - self.__slow_tasks: List[Future] = [] - self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = [] + self._run_out_filters = None + self.__slow_tasks = [] + self.__queued_stanzas = [] @property - def loop(self): + def loop(self) -> AbstractEventLoop: if self._loop is None: self._loop = asyncio.get_event_loop() return self._loop @loop.setter - def loop(self, value): + def loop(self, value: AbstractEventLoop) -> None: self._loop = value - def new_id(self): + def new_id(self) -> str: """Generate and return a new stream ID in hexadecimal form. Many stanzas, handlers, or matchers may require unique @@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol): """ return uuid.uuid4().hex - def _set_session_start(self, event): + def _set_session_start(self, event: Any) -> None: """ On session start, queue all pending stanzas to be sent. """ @@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol): self.waiting_queue.put_nowait(stanza) self.__queued_stanzas = [] - def _set_disconnected(self, event): + def _set_disconnected(self, event: Any) -> None: self._session_started = False - def _set_disconnected_future(self): + def _set_disconnected_future(self) -> None: """Set the self.disconnected future on disconnect""" if not self.disconnected.done(): self.disconnected.set_result(True) self.disconnected = asyncio.Future() - def connect(self, host='', port=0, use_ssl=False, - force_starttls=True, disable_starttls=False): + def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False, + force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None: """Create a new socket and connect to the server. :param host: The name of the desired server for the connection. @@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol): loop=self.loop, ) - async def _connect_routine(self): + async def _connect_routine(self) -> None: self.event_when_connected = "connected" if self._connect_loop_wait > 0: @@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol): # and try (host, port) as a last resort self._dns_answers = None + ssl_context: Optional[ssl.SSLContext] if self.use_ssl: ssl_context = self.get_ssl_context() else: @@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol): loop=self.loop, ) - def process(self, *, forever=True, timeout=None): + def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: """Process all the available XMPP events (receiving or sending data on the socket(s), calling various registered callbacks, calling expired timers, handling signal events, etc). If timeout is None, this @@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol): else: self.loop.run_until_complete(self.disconnected) else: - tasks = [asyncio.sleep(timeout, loop=self.loop)] + tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)] if not forever: tasks.append(self.disconnected) self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) - def init_parser(self): + def init_parser(self) -> None: """init the XML parser. The parser must always be reset for each new connexion """ @@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol): self.xml_root = None self.parser = ET.XMLPullParser(("start", "end")) - def connection_made(self, transport): + def connection_made(self, transport: BaseTransport) -> None: """Called when the TCP connection has been established with the server """ self.event(self.event_when_connected) - self.transport = transport + self.transport = cast(Transport, transport) + if self.transport is None: + raise ValueError("Transport cannot be none") self.socket = self.transport.get_extra_info( "ssl_object", default=self.transport.get_extra_info("socket") @@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol): self.send_raw(self.stream_header) self._dns_answers = None - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Called when incoming data is received on the socket. We feed that data to the parser and the see if this produced any XML @@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol): self.send(error) self.disconnect() - def is_connecting(self): + def is_connecting(self) -> bool: return self._current_connection_attempt is not None - def is_connected(self): + def is_connected(self) -> bool: return self.transport is not None - def eof_received(self): + def eof_received(self) -> None: """When the TCP connection is properly closed by the remote end """ self.event("eof_received") - def connection_lost(self, exception): + def connection_lost(self, exception: Optional[BaseException]) -> None: """On any kind of disconnection, initiated by us or not. This signals the closure of the TCP connection """ @@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol): self._reset_sendq() self.event('session_end') self._set_disconnected_future() - self.event("disconnected", self.disconnect_reason or exception and exception.strerror) + self.event("disconnected", self.disconnect_reason or exception) - def cancel_connection_attempt(self): + def cancel_connection_attempt(self) -> None: """ Immediately cancel the current create_connection() Future. This is useful when a client using slixmpp tries to connect @@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol): self._current_connection_attempt.cancel() self._current_connection_attempt = None - def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future: + def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]: """Close the XML stream and wait for an acknowldgement from the server for at most `wait` seconds. After the given number of seconds has passed without a response from the server, or when the server @@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol): # `disconnect(wait=True)` for ages. This doesn't mean anything to the # schedule call below. It would fortunately be converted to `1` later # down the call chain. Praise the implicit casts lord. - if wait == True: + if wait is True: wait = 2.0 if self.transport: @@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol): else: self._set_disconnected_future() self.event("disconnected", reason) - future = Future() + future: Future[None] = Future() future.set_result(None) return future - async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): + async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None: """Wait until the send queue is empty before disconnecting""" try: await asyncio.wait_for( @@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol): self.disconnect_reason = reason await self._end_stream_wait(wait) - async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None): + async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None: """ Run abort() if we do not received the disconnected event after a waiting time. @@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol): # that means the disconnect has already been handled pass - def abort(self): + def abort(self) -> None: """ Forcibly close the connection """ @@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol): self.transport.abort() self.event("killed") - def reconnect(self, wait=2.0, reason="Reconnecting"): + def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None: """Calls disconnect(), and once we are disconnected (after the timeout, or when the server acknowledgement is received), call connect() """ log.debug("reconnecting...") - async def handler(event): + async def handler(event: Any) -> None: # We yield here to allow synchronous handlers to work first await asyncio.sleep(0, loop=self.loop) self.connect() self.add_event_handler('disconnected', handler, disposable=True) self.disconnect(wait, reason) - def configure_socket(self): + def configure_socket(self) -> None: """Set timeout and other options for self.socket. Meant to be overridden. """ pass - def configure_dns(self, resolver, domain=None, port=None): + def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None: """ Configure and set options for a :class:`~dns.resolver.Resolver` instance, and other DNS related tasks. For example, you @@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol): """ pass - def get_ssl_context(self): + def get_ssl_context(self) -> ssl.SSLContext: """ Get SSL context. """ @@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol): return self.ssl_context - async def start_tls(self): + async def start_tls(self) -> bool: """Perform handshakes for TLS. If the handshake is successful, the XML stream will need to be restarted. """ + if self.transport is None: + raise ValueError("Transport should not be None") self.event_when_connected = "tls_success" ssl_context = self.get_ssl_context() try: @@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol): self.connection_made(transp) return True - def _start_keepalive(self, event): + def _start_keepalive(self, event: Any) -> None: """Begin sending whitespace periodically to keep the connection alive. May be disabled by setting:: @@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol): args=(' ',), repeat=True) - def _remove_schedules(self, event): + def _remove_schedules(self, event: Any) -> None: """Remove some schedules that become pointless when disconnected""" self.cancel_schedule('Whitespace Keepalive') - def start_stream_handler(self, xml): + def start_stream_handler(self, xml: ET.Element) -> None: """Perform any initialization actions, such as handshakes, once the stream header has been sent. @@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol): """ pass - def register_stanza(self, stanza_class): + def register_stanza(self, stanza_class: Type[StanzaBase]) -> None: """Add a stanza object class as a known root stanza. A root stanza is one that appears as a direct child of the stream's @@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol): """ self.__root_stanza.append(stanza_class) - def remove_stanza(self, stanza_class): + def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None: """Remove a stanza from being a known root stanza. A root stanza is one that appears as a direct child of the stream's @@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol): """ self.__root_stanza.remove(stanza_class) - def add_filter(self, mode, handler, order=None): + def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None: """Add a filter for incoming or outgoing stanzas. These filters are applied before incoming stanzas are @@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol): else: self.__filters[mode].append(handler) - def del_filter(self, mode, handler): + def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None: """Remove an incoming or outgoing filter.""" self.__filters[mode].remove(handler) - def register_handler(self, handler, before=None, after=None): + def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None: """Add a stream event handler that will be executed when a matching stanza is received. @@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol): self.__handlers.append(handler) handler.stream = weakref.ref(self) - def remove_handler(self, name): + def remove_handler(self, name: str) -> bool: """Remove any stream event handlers with the given name. :param name: The name of the handler. @@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol): try: return next(self._dns_answers) except StopIteration: - return + return None - def add_event_handler(self, name, pointer, disposable=False): + def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None: """Add a custom event handler that will be executed whenever its event is manually triggered. @@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol): self.__event_handlers[name] = [] self.__event_handlers[name].append((pointer, disposable)) - def del_event_handler(self, name, pointer): + def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None: """Remove a function as a handler for an event. :param name: The name of the event. @@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol): # Need to keep handlers that do not use # the given function pointer - def filter_pointers(handler): + def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool: return handler[0] != pointer self.__event_handlers[name] = list(filter( filter_pointers, self.__event_handlers[name])) - def event_handled(self, name): + def event_handled(self, name: str) -> int: """Returns the number of registered handlers for an event. :param name: The name of the event to check. """ return len(self.__event_handlers.get(name, [])) - async def event_async(self, name: str, data: Any = {}): + async def event_async(self, name: str, data: Any = {}) -> None: """Manually trigger a custom event, but await coroutines immediately. This event generator should only be called in situations when @@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol): except Exception as e: self.exception(e) - def event(self, name: str, data: Any = {}): + def event(self, name: str, data: Any = {}) -> None: """Manually trigger a custom event. Coroutine handlers are wrapped into a future and sent into the event loop for their execution, and not awaited. @@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol): # If the callback is a coroutine, schedule it instead of # running it directly if iscoroutinefunction(handler_callback): - async def handler_callback_routine(cb): + async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None: try: await cb(data) except Exception as e: @@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol): except ValueError: pass - def schedule(self, name, seconds, callback, args=tuple(), - kwargs={}, repeat=False): + def schedule(self, name: str, seconds: int, callback: Callable[..., None], + args: Tuple[Any, ...] = tuple(), + kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None: """Schedule a callback function to execute after a given delay. :param name: A unique name for the scheduled callback. @@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol): # canceling scheduled_events[name] self.scheduled_events[name] = handle - def cancel_schedule(self, name): + def cancel_schedule(self, name: str) -> None: try: handle = self.scheduled_events.pop(name) handle.cancel() except KeyError: log.debug("Tried to cancel unscheduled event: %s" % (name,)) - def _safe_cb_run(self, name, cb): + def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None: log.debug('Scheduled event: %s', name) try: cb() except Exception as e: self.exception(e) - def _execute_and_reschedule(self, name, cb, seconds): + def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None: """Simple method that calls the given callback, and then schedule itself to be called after the given number of seconds. """ @@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol): name, cb, seconds) self.scheduled_events[name] = handle - def _execute_and_unschedule(self, name, cb): + def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None: """ Execute the callback and remove the handler for it. """ @@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol): if name in self.scheduled_events: del self.scheduled_events[name] - def incoming_filter(self, xml): + def incoming_filter(self, xml: ET.Element) -> ET.Element: """Filter incoming XML objects before they are processed. Possible uses include remapping namespaces, or correcting elements @@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol): """ return xml - def _reset_sendq(self): + def _reset_sendq(self) -> None: """Clear sending tasks on session end""" # Cancel all pending slow send tasks log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks)) @@ -1042,8 +1165,8 @@ class XMLStream(asyncio.BaseProtocol): async def _continue_slow_send( self, - task: asyncio.Task, - already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]] + task: asyncio.Task[Optional[StanzaBase]], + already_used: Set[Filter] ) -> None: """ Used when an item in the send queue has taken too long to process. @@ -1062,12 +1185,14 @@ class XMLStream(asyncio.BaseProtocol): if iscoroutinefunction(filter): data = await filter(data) else: + filter = cast(SyncFilter, filter) data = filter(data) if data is None: return - if isinstance(data, ElementBase): + if isinstance(data, StanzaBase): for filter in self.__filters['out_sync']: + filter = cast(SyncFilter, filter) data = filter(data) if data is None: return @@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol): else: self.send_raw(data) - async def run_filters(self): + async def run_filters(self) -> NoReturn: """ Background loop that processes stanzas to send. """ while True: + data: Optional[Union[StanzaBase, str]] (data, use_filters) = await self.waiting_queue.get() try: - if isinstance(data, ElementBase): + if isinstance(data, StanzaBase): if use_filters: already_run_filters = set() for filter in self.__filters['out']: already_run_filters.add(filter) if iscoroutinefunction(filter): + filter = cast(AsyncFilter, filter) task = asyncio.create_task(filter(data)) completed, pending = await wait( {task}, @@ -1108,19 +1235,24 @@ class XMLStream(asyncio.BaseProtocol): "Slow coroutine, rescheduling filters" ) data = task.result() - else: + elif isinstance(data, StanzaBase): + filter = cast(SyncFilter, filter) data = filter(data) if data is None: raise ContinueQueue('Empty stanza') - if isinstance(data, ElementBase): + if isinstance(data, StanzaBase): if use_filters: for filter in self.__filters['out_sync']: + filter = cast(SyncFilter, filter) data = filter(data) if data is None: raise ContinueQueue('Empty stanza') - str_data = tostring(data.xml, xmlns=self.default_ns, - stream=self, top_level=True) + if isinstance(data, StanzaBase): + str_data = tostring(data.xml, xmlns=self.default_ns, + stream=self, top_level=True) + else: + str_data = data self.send_raw(str_data) else: self.send_raw(data) @@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol): log.error('Exception raised in send queue:', exc_info=True) self.waiting_queue.task_done() - def send(self, data, use_filters=True): + def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None: """A wrapper for :meth:`send_raw()` for sending stanza objects. - :param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` stanza to send on the stream. :param bool use_filters: Indicates if outgoing filters should be applied to the given stanza data. Disabling @@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol): return self.waiting_queue.put_nowait((data, use_filters)) - def send_xml(self, data): + def send_xml(self, data: ET.Element) -> None: """Send an XML object on the stream :param data: The :class:`~xml.etree.ElementTree.Element` XML object to send on the stream. """ - return self.send(tostring(data)) + self.send(tostring(data)) - def send_raw(self, data): + def send_raw(self, data: Union[str, bytes]) -> None: """Send raw data across the stream. :param string data: Any bytes or utf-8 string value. @@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol): data = data.encode('utf-8') self.transport.write(data) - def _build_stanza(self, xml, default_ns=None): + def _build_stanza(self, xml: ET.Element, + default_ns: Optional[str] = None) -> StanzaBase: """Create a stanza object from a given XML object. If a specialized stanza type is not found for the XML, then @@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol): stanza['lang'] = self.peer_default_lang return stanza - def _spawn_event(self, xml): + def _spawn_event(self, xml: ET.Element) -> None: """ Analyze incoming XML stanzas and convert them into stanza objects if applicable and queue stream events to be processed @@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol): # Convert the raw XML object into a stanza object. If no registered # stanza type applies, a generic StanzaBase stanza will be used. - stanza = self._build_stanza(xml) + stanza: Optional[StanzaBase] = self._build_stanza(xml) for filter in self.__filters['in']: if stanza is not None: + filter = cast(SyncFilter, filter) stanza = filter(stanza) if stanza is None: return @@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol): if not handled: stanza.unhandled() - def exception(self, exception): + def exception(self, exception: Exception) -> None: """Process an unknown exception. Meant to be overridden. @@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol): """ pass - async def wait_until(self, event: str, timeout=30) -> Any: + async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any: """Utility method to wake on the next firing of an event. (Registers a disposable handler on it) @@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol): :param int timeout: Timeout :raises: :class:`asyncio.TimeoutError` when the timeout is reached """ - fut = asyncio.Future() + fut: Future[Any] = asyncio.Future() - def result_handler(event_data): + def result_handler(event_data: Any) -> None: if not fut.done(): fut.set_result(event_data) else: @@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol): return await asyncio.wait_for(fut, timeout) @contextmanager - def event_handler(self, event: str, handler: Callable): + def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]: """ Context manager that adds then removes an event handler. """ self.add_event_handler(event, handler) try: yield - except Exception as exc: + except Exception: raise finally: self.del_event_handler(event, handler) - def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future: + def wrap(self, coroutine: Coroutine[None, None, T]) -> Future[T]: """Make a Future out of a coroutine with the current loop. :param coroutine: The coroutine to wrap. -- cgit v1.2.3