summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLink Mauve <linkmauve@linkmauve.fr>2021-01-29 16:11:29 +0100
committerLink Mauve <linkmauve@linkmauve.fr>2021-01-29 16:11:29 +0100
commitdbcd0c6050f6c50c24ff1a86129a29133371373e (patch)
tree1836f0e2981b2d7f64130bc11038ea9c9ee9cf86
parent370abb1d983bf1aabff523d4bbc5c9b89a2becb8 (diff)
parentf93af07882d19fd60af1696ccfa784ac4c03aa42 (diff)
downloadslixmpp-dbcd0c6050f6c50c24ff1a86129a29133371373e.tar.gz
slixmpp-dbcd0c6050f6c50c24ff1a86129a29133371373e.tar.bz2
slixmpp-dbcd0c6050f6c50c24ff1a86129a29133371373e.tar.xz
slixmpp-dbcd0c6050f6c50c24ff1a86129a29133371373e.zip
Merge branch 'reconnect-logic-doomed' into 'master'
fix reconnect logic See merge request poezio/slixmpp!104
-rw-r--r--slixmpp/plugins/xep_0198/stream_management.py31
-rw-r--r--slixmpp/xmlstream/xmlstream.py134
2 files changed, 114 insertions, 51 deletions
diff --git a/slixmpp/plugins/xep_0198/stream_management.py b/slixmpp/plugins/xep_0198/stream_management.py
index 0200646a..1344235a 100644
--- a/slixmpp/plugins/xep_0198/stream_management.py
+++ b/slixmpp/plugins/xep_0198/stream_management.py
@@ -174,6 +174,9 @@ class XEP_0198(BasePlugin):
def send_ack(self):
"""Send the current ack count to the server."""
+ if not self.xmpp.transport:
+ log.debug('Disconnected: not sending ack')
+ return
ack = stanza.Ack(self.xmpp)
ack['h'] = self.handled
self.xmpp.send_raw(str(ack))
@@ -198,20 +201,7 @@ class XEP_0198(BasePlugin):
# We've already negotiated stream management,
# so no need to do it again.
return False
- if not self.sm_id:
- if 'bind' in self.xmpp.features:
- enable = stanza.Enable(self.xmpp)
- enable['resume'] = self.allow_resume
- enable.send()
- log.debug("enabling SM")
-
- waiter = Waiter('enabled_or_failed',
- MatchMany([
- MatchXPath(stanza.Enabled.tag_name()),
- MatchXPath(stanza.Failed.tag_name())]))
- self.xmpp.register_handler(waiter)
- result = await waiter.wait()
- elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
+ if self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
resume = stanza.Resume(self.xmpp)
resume['h'] = self.handled
resume['previd'] = self.sm_id
@@ -229,6 +219,19 @@ class XEP_0198(BasePlugin):
result = await waiter.wait()
if result is not None and result.name == 'resumed':
return True
+ self.xmpp.event("session_end")
+ if 'bind' in self.xmpp.features:
+ enable = stanza.Enable(self.xmpp)
+ enable['resume'] = self.allow_resume
+ enable.send()
+ log.debug("enabling SM")
+
+ waiter = Waiter('enabled_or_failed',
+ MatchMany([
+ MatchXPath(stanza.Enabled.tag_name()),
+ MatchXPath(stanza.Failed.tag_name())]))
+ self.xmpp.register_handler(waiter)
+ result = await waiter.wait()
return False
def _handle_enabled(self, stanza):
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 6b890729..5074aa8c 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -12,7 +12,15 @@
:license: MIT, see LICENSE for more details
"""
-from typing import Optional, Set, Callable, Any
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Union,
+)
import functools
import logging
@@ -21,7 +29,7 @@ import ssl
import weakref
import uuid
-from asyncio import iscoroutinefunction, wait
+from asyncio import iscoroutinefunction, wait, Future
import xml.etree.ElementTree as ET
@@ -224,12 +232,13 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = None
#: An asyncio Future being done when the stream is disconnected.
- self.disconnected = asyncio.Future()
+ self.disconnected: Future = Future()
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
-
- self._run_filters = None
+
+ self._run_out_filters: Optional[Future] = None
+ self.__slow_tasks: List[Future] = []
@property
def loop(self):
@@ -250,6 +259,12 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
+ def _set_disconnected_future(self):
+ """Set the self.disconnected future on disconnect"""
+ if not self.disconnected.done():
+ self.disconnected.set_result(True)
+ self.disconnected = asyncio.Future()
+
def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False):
"""Create a new socket and connect to the server.
@@ -272,8 +287,8 @@ class XMLStream(asyncio.BaseProtocol):
localhost
"""
- if self._run_filters is None:
- self._run_filters = asyncio.ensure_future(
+ if self._run_out_filters is None or self._run_out_filters.done():
+ self._run_out_filters = asyncio.ensure_future(
self.run_filters(),
loop=self.loop,
)
@@ -418,10 +433,10 @@ class XMLStream(asyncio.BaseProtocol):
if self.xml_depth == 0:
# The stream's root element has closed,
# terminating the stream.
- self.end_session_on_disconnect = True
log.debug("End of stream received")
self.disconnect_reason = "End of stream"
self.abort()
+ return
elif self.xml_depth == 1:
# A stanza is an XML element that is a direct child of
# the root element, hence the check of depth == 1
@@ -463,11 +478,11 @@ class XMLStream(asyncio.BaseProtocol):
self.parser = None
self.transport = None
self.socket = None
- if self._run_filters:
- self._run_filters.cancel()
# Fire the events after cleanup
if self.end_session_on_disconnect:
+ self._reset_sendq()
self.event('session_end')
+ self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
def cancel_connection_attempt(self):
@@ -480,10 +495,8 @@ class XMLStream(asyncio.BaseProtocol):
if self._current_connection_attempt:
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
- if self._run_filters:
- self._run_filters.cancel()
- def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None:
+ def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
"""Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server
@@ -491,10 +504,13 @@ class XMLStream(asyncio.BaseProtocol):
called. If wait is 0.0, this will call abort() directly without closing
the stream.
- Does nothing if we are not connected.
+ Does nothing but trigger the disconnected event if we are not connected.
:param wait: Time to wait for a response from the server.
-
+ :param reason: An optional reason for the disconnect.
+ :param ignore_send_queue: Boolean to toggle if we want to ignore
+ the in-flight stanzas and disconnect immediately.
+ :return: A future that ends when all code involved in the disconnect has ended
"""
# Compat: docs/getting_started/sendlogout.rst has been promoting
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
@@ -504,50 +520,75 @@ class XMLStream(asyncio.BaseProtocol):
wait = 2.0
if self.transport:
+ self.disconnect_reason = reason
if self.waiting_queue.empty() or ignore_send_queue:
- self.disconnect_reason = reason
self.cancel_connection_attempt()
- if wait > 0.0:
- self.send_raw(self.stream_footer)
- self.schedule('Disconnect wait', wait,
- self.abort, repeat=False)
+ return asyncio.ensure_future(
+ self._end_stream_wait(wait, reason=reason),
+ loop=self.loop,
+ )
else:
- asyncio.ensure_future(
+ return asyncio.ensure_future(
self._consume_send_queue_before_disconnecting(reason, wait),
loop=self.loop,
)
else:
+ self._set_disconnected_future()
self.event("disconnected", reason)
+ future = Future()
+ future.set_result(None)
+ return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
"""Wait until the send queue is empty before disconnecting"""
- await self.waiting_queue.join()
+ try:
+ await asyncio.wait_for(
+ self.waiting_queue.join(),
+ wait,
+ loop=self.loop
+ )
+ except asyncio.TimeoutError:
+ wait = 0 # we already consumed the timeout
self.disconnect_reason = reason
- self.cancel_connection_attempt()
- if wait > 0.0:
+ await self._end_stream_wait(wait)
+
+ async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
+ """
+ Run abort() if we do not received the disconnected event
+ after a waiting time.
+
+ :param wait: The waiting time (defaults to 2)
+ """
+ try:
self.send_raw(self.stream_footer)
- self.schedule('Disconnect wait', wait,
- self.abort, repeat=False)
+ await self.wait_until('disconnected', wait)
+ except asyncio.TimeoutError:
+ self.abort()
+ except NotConnectedError:
+ # We are not connected when sending the end of stream
+ # that means the disconnect has already been handled
+ pass
def abort(self):
"""
Forcibly close the connection
"""
- self.cancel_connection_attempt()
if self.transport:
+ self.cancel_connection_attempt()
self.transport.close()
self.transport.abort()
self.event("killed")
- self.disconnected.set_result(True)
- self.disconnected = asyncio.Future()
- self.event("disconnected", self.disconnect_reason)
def reconnect(self, wait=2.0, reason="Reconnecting"):
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
- self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True)
+ async def handler(event):
+ # We yield here to allow synchronous handlers to work first
+ await asyncio.sleep(0, loop=self.loop)
+ self.connect()
+ self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
def configure_socket(self):
@@ -655,7 +696,6 @@ class XMLStream(asyncio.BaseProtocol):
def _remove_schedules(self, event):
"""Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive')
- self.cancel_schedule('Disconnect wait')
def start_stream_handler(self, xml):
"""Perform any initialization actions, such as handshakes,
@@ -833,7 +873,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
log.debug("Event triggered: %s", name)
- handlers = self.__event_handlers.get(name, [])
+ handlers = self.__event_handlers.get(name, [])[:]
for handler in handlers:
handler_callback, disposable = handler
old_exception = getattr(data, 'exception', None)
@@ -941,6 +981,18 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
+ def _reset_sendq(self):
+ """Clear sending tasks on session end"""
+ # Cancel all pending slow send tasks
+ log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
+ for slow_task in self.__slow_tasks:
+ slow_task.cancel()
+ self.__slow_tasks.clear()
+ # Purge pending stanzas
+ while not self.waiting_queue.empty():
+ discarded = self.waiting_queue.get_nowait()
+ log.debug('Discarded stanza: %s', discarded)
+
async def _continue_slow_send(
self,
task: asyncio.Task,
@@ -954,6 +1006,7 @@ class XMLStream(asyncio.BaseProtocol):
:param set already_used: Filters already used on this outgoing stanza
"""
data = await task
+ self.__slow_tasks.remove(task)
for filter in self.__filters['out']:
if filter in already_used:
continue
@@ -975,7 +1028,6 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
-
async def run_filters(self):
"""
Background loop that processes stanzas to send.
@@ -995,11 +1047,13 @@ class XMLStream(asyncio.BaseProtocol):
timeout=1,
)
if pending:
+ self.slow_tasks.append(task)
asyncio.ensure_future(
self._continue_slow_send(
task,
already_run_filters
- )
+ ),
+ loop=self.loop,
)
raise Exception("Slow coro, rescheduling")
data = task.result()
@@ -1142,9 +1196,15 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout
"""
fut = asyncio.Future()
+ def result_handler(event_data):
+ if not fut.done():
+ fut.set_result(event_data)
+ else:
+ log.debug("Future registered on event '%s' was alredy done", event)
+
self.add_event_handler(
event,
- fut.set_result,
+ result_handler,
disposable=True,
)
- return await asyncio.wait_for(fut, timeout)
+ return await asyncio.wait_for(fut, timeout, loop=self.loop)