diff options
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 198 |
1 files changed, 156 insertions, 42 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 17d23ff2..af494903 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -12,6 +12,8 @@ :license: MIT, see LICENSE for more details """ +from typing import Optional, Set, Callable + import functools import logging import socket as Socket @@ -19,6 +21,8 @@ import ssl import weakref import uuid +from asyncio import iscoroutinefunction, wait + import xml.etree.ElementTree as ET from slixmpp.xmlstream.asyncio import asyncio @@ -30,6 +34,10 @@ from slixmpp.xmlstream.resolver import resolve, default_resolver RESPONSE_TIMEOUT = 30 log = logging.getLogger(__name__) +class ContinueQueue(Exception): + """ + Exception raised in the send queue to "continue" from within an inner loop + """ class NotConnectedError(Exception): """ @@ -81,6 +89,8 @@ class XMLStream(asyncio.BaseProtocol): self.force_starttls = None self.disable_starttls = None + self.waiting_queue = asyncio.Queue() + # A dict of {name: handle} self.scheduled_events = {} @@ -199,11 +209,6 @@ class XMLStream(asyncio.BaseProtocol): self.__event_handlers = {} self.__filters = {'in': [], 'out': [], 'out_sync': []} - self._id = 0 - - #: We use an ID prefix to ensure that all ID values are unique. - self._id_prefix = '%s-' % uuid.uuid4() - # Current connection attempt (Future) self._current_connection_attempt = None @@ -223,6 +228,8 @@ class XMLStream(asyncio.BaseProtocol): self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) + + self._run_filters = None @property def loop(self): @@ -241,12 +248,7 @@ class XMLStream(asyncio.BaseProtocol): ID values. Using this method ensures that all new ID values are unique in this stream. """ - self._id += 1 - return self.get_id() - - def get_id(self): - """Return the current unique stream ID in hexadecimal form.""" - return "%s%X" % (self._id_prefix, self._id) + return uuid.uuid4().hex def connect(self, host='', port=0, use_ssl=False, force_starttls=True, disable_starttls=False): @@ -271,8 +273,15 @@ class XMLStream(asyncio.BaseProtocol): localhost """ + if self._run_filters is None: + self._run_filters = asyncio.ensure_future( + self.run_filters(), + loop=self.loop, + ) + self.disconnect_reason = None self.cancel_connection_attempt() + self.connect_loop_wait = 0 if host and port: self.address = (host, int(port)) try: @@ -297,6 +306,10 @@ class XMLStream(asyncio.BaseProtocol): async def _connect_routine(self): self.event_when_connected = "connected" + if self.connect_loop_wait > 0: + self.event('reconnect_delay', self.connect_loop_wait) + await asyncio.sleep(self.connect_loop_wait, loop=self.loop) + record = await self.pick_dns_answer(self.default_domain) if record is not None: host, address, dns_port = record @@ -313,7 +326,6 @@ class XMLStream(asyncio.BaseProtocol): else: ssl_context = None - await asyncio.sleep(self.connect_loop_wait, loop=self.loop) if self._current_connection_attempt is None: return try: @@ -372,6 +384,7 @@ class XMLStream(asyncio.BaseProtocol): "ssl_object", default=self.transport.get_extra_info("socket") ) + self._current_connection_attempt = None self.init_parser() self.send_raw(self.stream_header) self.dns_answers = None @@ -430,6 +443,9 @@ class XMLStream(asyncio.BaseProtocol): self.send(error) self.disconnect() + def is_connecting(self): + return self._current_connection_attempt is not None + def is_connected(self): return self.transport is not None @@ -451,6 +467,8 @@ class XMLStream(asyncio.BaseProtocol): self.parser = None self.transport = None self.socket = None + if self._run_filters: + self._run_filters.cancel() def cancel_connection_attempt(self): """ @@ -462,11 +480,14 @@ class XMLStream(asyncio.BaseProtocol): if self._current_connection_attempt: self._current_connection_attempt.cancel() self._current_connection_attempt = None + if self._run_filters: + self._run_filters.cancel() + - def disconnect(self, wait=2.0, reason=None): + def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None: """Close the XML stream and wait for an acknowldgement from the server for at most `wait` seconds. After the given number of seconds has - passed without a response from the serveur, or when the server + passed without a response from the server, or when the server successfully responds with a closure of its own stream, abort() is called. If wait is 0.0, this will call abort() directly without closing the stream. @@ -476,13 +497,38 @@ class XMLStream(asyncio.BaseProtocol): :param wait: Time to wait for a response from the server. """ + # Compat: docs/getting_started/sendlogout.rst has been promoting + # `disconnect(wait=True)` for ages. This doesn't mean anything to the + # schedule call below. It would fortunately be converted to `1` later + # down the call chain. Praise the implicit casts lord. + if wait == True: + wait = 2.0 + + if self.transport: + if self.waiting_queue.empty() or ignore_send_queue: + self.disconnect_reason = reason + self.cancel_connection_attempt() + if wait > 0.0: + self.send_raw(self.stream_footer) + self.schedule('Disconnect wait', wait, + self.abort, repeat=False) + else: + asyncio.ensure_future( + self._consume_send_queue_before_disconnecting(reason, wait), + loop=self.loop, + ) + else: + self.event("disconnected", reason) + + async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): + """Wait until the send queue is empty before disconnecting""" + await self.waiting_queue.join() self.disconnect_reason = reason self.cancel_connection_attempt() - if self.transport: - if wait > 0.0: - self.send_raw(self.stream_footer) - self.schedule('Disconnect wait', wait, - self.abort, repeat=False) + if wait > 0.0: + self.send_raw(self.stream_footer) + self.schedule('Disconnect wait', wait, + self.abort, repeat=False) def abort(self): """ @@ -495,14 +541,15 @@ class XMLStream(asyncio.BaseProtocol): self.event("killed") self.disconnected.set_result(True) self.disconnected = asyncio.Future() + self.event("disconnected", self.disconnect_reason) def reconnect(self, wait=2.0, reason="Reconnecting"): """Calls disconnect(), and once we are disconnected (after the timeout, or when the server acknowledgement is received), call connect() """ log.debug("reconnecting...") - self.disconnect(wait, reason) self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True) + self.disconnect(wait, reason) def configure_socket(self): """Set timeout and other options for self.socket. @@ -790,7 +837,7 @@ class XMLStream(asyncio.BaseProtocol): # If the callback is a coroutine, schedule it instead of # running it directly - if asyncio.iscoroutinefunction(handler_callback): + if iscoroutinefunction(handler_callback): async def handler_callback_routine(cb): try: await cb(data) @@ -877,7 +924,9 @@ class XMLStream(asyncio.BaseProtocol): Execute the callback and remove the handler for it. """ self._safe_cb_run(name, cb) - del self.scheduled_events[name] + # workaround for specific events which unschedule themselves + if name in self.scheduled_events: + del self.scheduled_events[name] def incoming_filter(self, xml): """Filter incoming XML objects before they are processed. @@ -889,11 +938,93 @@ class XMLStream(asyncio.BaseProtocol): """ return xml + async def _continue_slow_send( + self, + task: asyncio.Task, + already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]] + ) -> None: + """ + Used when an item in the send queue has taken too long to process. + + This is away from the send queue and can take as much time as needed. + :param asyncio.Task task: the Task wrapping the coroutine + :param set already_used: Filters already used on this outgoing stanza + """ + data = await task + for filter in self.__filters['out']: + if filter in already_used: + continue + if iscoroutinefunction(filter): + data = await task + else: + data = filter(data) + if data is None: + return + + if isinstance(data, ElementBase): + for filter in self.__filters['out_sync']: + data = filter(data) + if data is None: + return + str_data = tostring(data.xml, xmlns=self.default_ns, + stream=self, top_level=True) + self.send_raw(str_data) + else: + self.send_raw(data) + + + async def run_filters(self): + """ + Background loop that processes stanzas to send. + """ + while True: + (data, use_filters) = await self.waiting_queue.get() + try: + if isinstance(data, ElementBase): + if use_filters: + already_run_filters = set() + for filter in self.__filters['out']: + already_run_filters.add(filter) + if iscoroutinefunction(filter): + task = asyncio.create_task(filter(data)) + completed, pending = await wait( + {task}, + timeout=1, + ) + if pending: + asyncio.ensure_future( + self._continue_slow_send( + task, + already_run_filters + ) + ) + raise Exception("Slow coro, rescheduling") + data = task.result() + else: + data = filter(data) + if data is None: + raise ContinueQueue('Empty stanza') + + if isinstance(data, ElementBase): + if use_filters: + for filter in self.__filters['out_sync']: + data = filter(data) + if data is None: + raise ContinueQueue('Empty stanza') + str_data = tostring(data.xml, xmlns=self.default_ns, + stream=self, top_level=True) + self.send_raw(str_data) + else: + self.send_raw(data) + except ContinueQueue as exc: + log.debug('Stanza in send queue not sent: %s', exc) + except Exception: + log.error('Exception raised in send queue:', exc_info=True) + self.waiting_queue.task_done() + def send(self, data, use_filters=True): """A wrapper for :meth:`send_raw()` for sending stanza objects. - May optionally block until an expected response is received. - :param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` stanza to send on the stream. :param bool use_filters: Indicates if outgoing filters should be @@ -901,24 +1032,7 @@ class XMLStream(asyncio.BaseProtocol): filters is useful when resending stanzas. Defaults to ``True``. """ - if isinstance(data, ElementBase): - if use_filters: - for filter in self.__filters['out']: - data = filter(data) - if data is None: - return - - if isinstance(data, ElementBase): - if use_filters: - for filter in self.__filters['out_sync']: - data = filter(data) - if data is None: - return - str_data = tostring(data.xml, xmlns=self.default_ns, - stream=self, top_level=True) - self.send_raw(str_data) - else: - self.send_raw(data) + self.waiting_queue.put_nowait((data, use_filters)) def send_xml(self, data): """Send an XML object on the stream |