diff options
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 58 |
1 files changed, 42 insertions, 16 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index a67c337d..9fb38f46 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -489,7 +489,7 @@ class XMLStream(asyncio.BaseProtocol): self._current_connection_attempt.cancel() self._current_connection_attempt = None - def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None: + def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future: """Close the XML stream and wait for an acknowldgement from the server for at most `wait` seconds. After the given number of seconds has passed without a response from the server, or when the server @@ -497,10 +497,13 @@ class XMLStream(asyncio.BaseProtocol): called. If wait is 0.0, this will call abort() directly without closing the stream. - Does nothing if we are not connected. + Does nothing but trigger the disconnected event if we are not connected. :param wait: Time to wait for a response from the server. - + :param reason: An optional reason for the disconnect. + :param ignore_send_queue: Boolean to toggle if we want to ignore + the in-flight stanzas and disconnect immediately. + :return: A future that ends when all code involved in the disconnect has ended """ # Compat: docs/getting_started/sendlogout.rst has been promoting # `disconnect(wait=True)` for ages. This doesn't mean anything to the @@ -510,39 +513,63 @@ class XMLStream(asyncio.BaseProtocol): wait = 2.0 if self.transport: + self.disconnect_reason = reason if self.waiting_queue.empty() or ignore_send_queue: - self.disconnect_reason = reason self.cancel_connection_attempt() - if wait > 0.0: - self.send_raw(self.stream_footer) - self.schedule('Disconnect wait', wait, - self.abort, repeat=False) + return asyncio.ensure_future( + self._end_stream_wait(wait, reason=reason), + loop=self.loop, + ) else: - asyncio.ensure_future( + return asyncio.ensure_future( self._consume_send_queue_before_disconnecting(reason, wait), loop=self.loop, ) else: self.event("disconnected", reason) + future = Future() + future.set_result(None) + return future async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): """Wait until the send queue is empty before disconnecting""" - await self.waiting_queue.join() + try: + await asyncio.wait_for( + self.waiting_queue.join(), + wait, + loop=self.loop + ) + except asyncio.TimeoutError: + wait = 0 # we already consumed the timeout self.disconnect_reason = reason - self.cancel_connection_attempt() - if wait > 0.0: + await self._end_stream_wait(wait) + + async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None): + """ + Run abort() if we do not received the disconnected event + after a waiting time. + + :param wait: The waiting time (defaults to 2) + """ + try: self.send_raw(self.stream_footer) - self.schedule('Disconnect wait', wait, - self.abort, repeat=False) + await self.wait_until('disconnected', wait) + except asyncio.TimeoutError: + self.abort() + except NotConnectedError: + # We are not connected when sending the end of stream + # that means the disconnect has already been handled + pass def abort(self): """ Forcibly close the connection """ - self.cancel_connection_attempt() if self.transport: + self.cancel_connection_attempt() self.transport.close() self.transport.abort() + self.transport = None self.event("killed") self.disconnected.set_result(True) self.disconnected = asyncio.Future() @@ -661,7 +688,6 @@ class XMLStream(asyncio.BaseProtocol): def _remove_schedules(self, event): """Remove some schedules that become pointless when disconnected""" self.cancel_schedule('Whitespace Keepalive') - self.cancel_schedule('Disconnect wait') def start_stream_handler(self, xml): """Perform any initialization actions, such as handshakes, |