summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r--slixmpp/xmlstream/xmlstream.py198
1 files changed, 156 insertions, 42 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 17d23ff2..af494903 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -12,6 +12,8 @@
:license: MIT, see LICENSE for more details
"""
+from typing import Optional, Set, Callable
+
import functools
import logging
import socket as Socket
@@ -19,6 +21,8 @@ import ssl
import weakref
import uuid
+from asyncio import iscoroutinefunction, wait
+
import xml.etree.ElementTree as ET
from slixmpp.xmlstream.asyncio import asyncio
@@ -30,6 +34,10 @@ from slixmpp.xmlstream.resolver import resolve, default_resolver
RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__)
+class ContinueQueue(Exception):
+ """
+ Exception raised in the send queue to "continue" from within an inner loop
+ """
class NotConnectedError(Exception):
"""
@@ -81,6 +89,8 @@ class XMLStream(asyncio.BaseProtocol):
self.force_starttls = None
self.disable_starttls = None
+ self.waiting_queue = asyncio.Queue()
+
# A dict of {name: handle}
self.scheduled_events = {}
@@ -199,11 +209,6 @@ class XMLStream(asyncio.BaseProtocol):
self.__event_handlers = {}
self.__filters = {'in': [], 'out': [], 'out_sync': []}
- self._id = 0
-
- #: We use an ID prefix to ensure that all ID values are unique.
- self._id_prefix = '%s-' % uuid.uuid4()
-
# Current connection attempt (Future)
self._current_connection_attempt = None
@@ -223,6 +228,8 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
+
+ self._run_filters = None
@property
def loop(self):
@@ -241,12 +248,7 @@ class XMLStream(asyncio.BaseProtocol):
ID values. Using this method ensures that all new ID values
are unique in this stream.
"""
- self._id += 1
- return self.get_id()
-
- def get_id(self):
- """Return the current unique stream ID in hexadecimal form."""
- return "%s%X" % (self._id_prefix, self._id)
+ return uuid.uuid4().hex
def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False):
@@ -271,8 +273,15 @@ class XMLStream(asyncio.BaseProtocol):
localhost
"""
+ if self._run_filters is None:
+ self._run_filters = asyncio.ensure_future(
+ self.run_filters(),
+ loop=self.loop,
+ )
+
self.disconnect_reason = None
self.cancel_connection_attempt()
+ self.connect_loop_wait = 0
if host and port:
self.address = (host, int(port))
try:
@@ -297,6 +306,10 @@ class XMLStream(asyncio.BaseProtocol):
async def _connect_routine(self):
self.event_when_connected = "connected"
+ if self.connect_loop_wait > 0:
+ self.event('reconnect_delay', self.connect_loop_wait)
+ await asyncio.sleep(self.connect_loop_wait, loop=self.loop)
+
record = await self.pick_dns_answer(self.default_domain)
if record is not None:
host, address, dns_port = record
@@ -313,7 +326,6 @@ class XMLStream(asyncio.BaseProtocol):
else:
ssl_context = None
- await asyncio.sleep(self.connect_loop_wait, loop=self.loop)
if self._current_connection_attempt is None:
return
try:
@@ -372,6 +384,7 @@ class XMLStream(asyncio.BaseProtocol):
"ssl_object",
default=self.transport.get_extra_info("socket")
)
+ self._current_connection_attempt = None
self.init_parser()
self.send_raw(self.stream_header)
self.dns_answers = None
@@ -430,6 +443,9 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error)
self.disconnect()
+ def is_connecting(self):
+ return self._current_connection_attempt is not None
+
def is_connected(self):
return self.transport is not None
@@ -451,6 +467,8 @@ class XMLStream(asyncio.BaseProtocol):
self.parser = None
self.transport = None
self.socket = None
+ if self._run_filters:
+ self._run_filters.cancel()
def cancel_connection_attempt(self):
"""
@@ -462,11 +480,14 @@ class XMLStream(asyncio.BaseProtocol):
if self._current_connection_attempt:
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
+ if self._run_filters:
+ self._run_filters.cancel()
+
- def disconnect(self, wait=2.0, reason=None):
+ def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None:
"""Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has
- passed without a response from the serveur, or when the server
+ passed without a response from the server, or when the server
successfully responds with a closure of its own stream, abort() is
called. If wait is 0.0, this will call abort() directly without closing
the stream.
@@ -476,13 +497,38 @@ class XMLStream(asyncio.BaseProtocol):
:param wait: Time to wait for a response from the server.
"""
+ # Compat: docs/getting_started/sendlogout.rst has been promoting
+ # `disconnect(wait=True)` for ages. This doesn't mean anything to the
+ # schedule call below. It would fortunately be converted to `1` later
+ # down the call chain. Praise the implicit casts lord.
+ if wait == True:
+ wait = 2.0
+
+ if self.transport:
+ if self.waiting_queue.empty() or ignore_send_queue:
+ self.disconnect_reason = reason
+ self.cancel_connection_attempt()
+ if wait > 0.0:
+ self.send_raw(self.stream_footer)
+ self.schedule('Disconnect wait', wait,
+ self.abort, repeat=False)
+ else:
+ asyncio.ensure_future(
+ self._consume_send_queue_before_disconnecting(reason, wait),
+ loop=self.loop,
+ )
+ else:
+ self.event("disconnected", reason)
+
+ async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
+ """Wait until the send queue is empty before disconnecting"""
+ await self.waiting_queue.join()
self.disconnect_reason = reason
self.cancel_connection_attempt()
- if self.transport:
- if wait > 0.0:
- self.send_raw(self.stream_footer)
- self.schedule('Disconnect wait', wait,
- self.abort, repeat=False)
+ if wait > 0.0:
+ self.send_raw(self.stream_footer)
+ self.schedule('Disconnect wait', wait,
+ self.abort, repeat=False)
def abort(self):
"""
@@ -495,14 +541,15 @@ class XMLStream(asyncio.BaseProtocol):
self.event("killed")
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
+ self.event("disconnected", self.disconnect_reason)
def reconnect(self, wait=2.0, reason="Reconnecting"):
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
- self.disconnect(wait, reason)
self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True)
+ self.disconnect(wait, reason)
def configure_socket(self):
"""Set timeout and other options for self.socket.
@@ -790,7 +837,7 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of
# running it directly
- if asyncio.iscoroutinefunction(handler_callback):
+ if iscoroutinefunction(handler_callback):
async def handler_callback_routine(cb):
try:
await cb(data)
@@ -877,7 +924,9 @@ class XMLStream(asyncio.BaseProtocol):
Execute the callback and remove the handler for it.
"""
self._safe_cb_run(name, cb)
- del self.scheduled_events[name]
+ # workaround for specific events which unschedule themselves
+ if name in self.scheduled_events:
+ del self.scheduled_events[name]
def incoming_filter(self, xml):
"""Filter incoming XML objects before they are processed.
@@ -889,11 +938,93 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
+ async def _continue_slow_send(
+ self,
+ task: asyncio.Task,
+ already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
+ ) -> None:
+ """
+ Used when an item in the send queue has taken too long to process.
+
+ This is away from the send queue and can take as much time as needed.
+ :param asyncio.Task task: the Task wrapping the coroutine
+ :param set already_used: Filters already used on this outgoing stanza
+ """
+ data = await task
+ for filter in self.__filters['out']:
+ if filter in already_used:
+ continue
+ if iscoroutinefunction(filter):
+ data = await task
+ else:
+ data = filter(data)
+ if data is None:
+ return
+
+ if isinstance(data, ElementBase):
+ for filter in self.__filters['out_sync']:
+ data = filter(data)
+ if data is None:
+ return
+ str_data = tostring(data.xml, xmlns=self.default_ns,
+ stream=self, top_level=True)
+ self.send_raw(str_data)
+ else:
+ self.send_raw(data)
+
+
+ async def run_filters(self):
+ """
+ Background loop that processes stanzas to send.
+ """
+ while True:
+ (data, use_filters) = await self.waiting_queue.get()
+ try:
+ if isinstance(data, ElementBase):
+ if use_filters:
+ already_run_filters = set()
+ for filter in self.__filters['out']:
+ already_run_filters.add(filter)
+ if iscoroutinefunction(filter):
+ task = asyncio.create_task(filter(data))
+ completed, pending = await wait(
+ {task},
+ timeout=1,
+ )
+ if pending:
+ asyncio.ensure_future(
+ self._continue_slow_send(
+ task,
+ already_run_filters
+ )
+ )
+ raise Exception("Slow coro, rescheduling")
+ data = task.result()
+ else:
+ data = filter(data)
+ if data is None:
+ raise ContinueQueue('Empty stanza')
+
+ if isinstance(data, ElementBase):
+ if use_filters:
+ for filter in self.__filters['out_sync']:
+ data = filter(data)
+ if data is None:
+ raise ContinueQueue('Empty stanza')
+ str_data = tostring(data.xml, xmlns=self.default_ns,
+ stream=self, top_level=True)
+ self.send_raw(str_data)
+ else:
+ self.send_raw(data)
+ except ContinueQueue as exc:
+ log.debug('Stanza in send queue not sent: %s', exc)
+ except Exception:
+ log.error('Exception raised in send queue:', exc_info=True)
+ self.waiting_queue.task_done()
+
def send(self, data, use_filters=True):
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
- May optionally block until an expected response is received.
-
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
stanza to send on the stream.
:param bool use_filters: Indicates if outgoing filters should be
@@ -901,24 +1032,7 @@ class XMLStream(asyncio.BaseProtocol):
filters is useful when resending stanzas.
Defaults to ``True``.
"""
- if isinstance(data, ElementBase):
- if use_filters:
- for filter in self.__filters['out']:
- data = filter(data)
- if data is None:
- return
-
- if isinstance(data, ElementBase):
- if use_filters:
- for filter in self.__filters['out_sync']:
- data = filter(data)
- if data is None:
- return
- str_data = tostring(data.xml, xmlns=self.default_ns,
- stream=self, top_level=True)
- self.send_raw(str_data)
- else:
- self.send_raw(data)
+ self.waiting_queue.put_nowait((data, use_filters))
def send_xml(self, data):
"""Send an XML object on the stream