summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/xmlstream/xmlstream.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 5b245e11..a67c337d 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -12,7 +12,15 @@
:license: MIT, see LICENSE for more details
"""
-from typing import Optional, Set, Callable, Any
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Union,
+)
import functools
import logging
@@ -21,7 +29,7 @@ import ssl
import weakref
import uuid
-from asyncio import iscoroutinefunction, wait
+from asyncio import iscoroutinefunction, wait, Future
import xml.etree.ElementTree as ET
@@ -228,8 +236,9 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
-
+
self._run_filters = None
+ self.__slow_tasks: List[Future] = []
@property
def loop(self):
@@ -465,6 +474,7 @@ class XMLStream(asyncio.BaseProtocol):
self.socket = None
# Fire the events after cleanup
if self.end_session_on_disconnect:
+ self._reset_sendq()
self.event('session_end')
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
@@ -937,6 +947,18 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
+ def _reset_sendq(self):
+ """Clear sending tasks on session end"""
+ # Cancel all pending slow send tasks
+ log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
+ for slow_task in self.__slow_tasks:
+ slow_task.cancel()
+ self.__slow_tasks.clear()
+ # Purge pending stanzas
+ while not self.waiting_queue.empty():
+ discarded = self.waiting_queue.get_nowait()
+ log.debug('Discarded stanza: %s', discarded)
+
async def _continue_slow_send(
self,
task: asyncio.Task,
@@ -950,6 +972,7 @@ class XMLStream(asyncio.BaseProtocol):
:param set already_used: Filters already used on this outgoing stanza
"""
data = await task
+ self.__slow_tasks.remove(task)
for filter in self.__filters['out']:
if filter in already_used:
continue
@@ -990,6 +1013,7 @@ class XMLStream(asyncio.BaseProtocol):
timeout=1,
)
if pending:
+ self.slow_tasks.append(task)
asyncio.ensure_future(
self._continue_slow_send(
task,