summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/plugins/xep_0198/stream_management.py93
1 files changed, 48 insertions, 45 deletions
diff --git a/slixmpp/plugins/xep_0198/stream_management.py b/slixmpp/plugins/xep_0198/stream_management.py
index acf37cd7..fbc9e023 100644
--- a/slixmpp/plugins/xep_0198/stream_management.py
+++ b/slixmpp/plugins/xep_0198/stream_management.py
@@ -6,8 +6,8 @@
See the file LICENSE for copying permission.
"""
+import asyncio
import logging
-import threading
import collections
from slixmpp.stanza import Message, Presence, Iq, StreamFeatures
@@ -70,15 +70,10 @@ class XEP_0198(BasePlugin):
return
self.window_counter = self.window
- self.window_counter_lock = threading.Lock()
- self.enabled = threading.Event()
+ self.enabled = False
self.unacked_queue = collections.deque()
- self.seq_lock = threading.Lock()
- self.handled_lock = threading.Lock()
- self.ack_lock = threading.Lock()
-
register_stanza_plugin(StreamFeatures, stanza.StreamManagement)
self.xmpp.register_stanza(stanza.Enable)
self.xmpp.register_stanza(stanza.Enabled)
@@ -161,7 +156,7 @@ class XEP_0198(BasePlugin):
def session_end(self, event):
"""Reset stream management state."""
- self.enabled.clear()
+ self.enabled = False
self.unacked_queue.clear()
self.sm_id = None
self.handled = 0
@@ -171,15 +166,15 @@ class XEP_0198(BasePlugin):
def send_ack(self):
"""Send the current ack count to the server."""
ack = stanza.Ack(self.xmpp)
- with self.handled_lock:
- ack['h'] = self.handled
+ ack['h'] = self.handled
self.xmpp.send_raw(str(ack))
def request_ack(self, e=None):
"""Request an ack from the server."""
req = stanza.RequestAck(self.xmpp)
- self.xmpp.send_queue.put(str(req))
+ self.xmpp.send_raw(str(req))
+ @asyncio.coroutine
def _handle_sm_feature(self, features):
"""
Enable or resume stream management.
@@ -196,13 +191,21 @@ class XEP_0198(BasePlugin):
return False
if not self.sm_id:
if 'bind' in self.xmpp.features:
- self.enabled.set()
enable = stanza.Enable(self.xmpp)
enable['resume'] = self.allow_resume
enable.send()
+ self.enabled = True
self.handled = 0
- elif self.sm_id and self.allow_resume:
- self.enabled.set()
+ self.unacked_queue.clear()
+
+ waiter = Waiter('enabled_or_failed',
+ MatchMany([
+ MatchXPath(stanza.Enabled.tag_name()),
+ MatchXPath(stanza.Failed.tag_name())]))
+ self.xmpp.register_handler(waiter)
+ result = yield from waiter.wait()
+ elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
+ self.enabled = True
resume = stanza.Resume(self.xmpp)
resume['h'] = self.handled
resume['previd'] = self.sm_id
@@ -216,7 +219,7 @@ class XEP_0198(BasePlugin):
MatchXPath(stanza.Resumed.tag_name()),
MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter)
- result = waiter.wait()
+ result = yield from waiter.wait()
if result is not None and result.name == 'resumed':
return True
return False
@@ -250,7 +253,7 @@ class XEP_0198(BasePlugin):
Raises an :term:`sm_failed` event.
"""
- self.enabled.clear()
+ self.enabled = False
self.unacked_queue.clear()
self.xmpp.event('sm_failed', stanza)
@@ -262,21 +265,24 @@ class XEP_0198(BasePlugin):
if ack['h'] == self.last_ack:
return
- with self.ack_lock:
- num_acked = (ack['h'] - self.last_ack) % MAX_SEQ
- num_unacked = len(self.unacked_queue)
- log.debug("Ack: %s, Last Ack: %s, " + \
- "Unacked: %s, Num Acked: %s, " + \
- "Remaining: %s",
- ack['h'],
- self.last_ack,
- num_unacked,
- num_acked,
- num_unacked - num_acked)
- for x in range(num_acked):
- seq, stanza = self.unacked_queue.popleft()
- self.xmpp.event('stanza_acked', stanza)
- self.last_ack = ack['h']
+ num_acked = (ack['h'] - self.last_ack) % MAX_SEQ
+ num_unacked = len(self.unacked_queue)
+ log.debug("Ack: %s, Last Ack: %s, " + \
+ "Unacked: %s, Num Acked: %s, " + \
+ "Remaining: %s",
+ ack['h'],
+ self.last_ack,
+ num_unacked,
+ num_acked,
+ num_unacked - num_acked)
+ if num_acked > len(self.unacked_queue) or num_acked < 0:
+ log.error('Inconsistent sequence numbers from the server,'
+ ' ignoring and replacing ours with them.')
+ num_acked = len(self.unacked_queue)
+ for x in range(num_acked):
+ seq, stanza = self.unacked_queue.popleft()
+ self.xmpp.event('stanza_acked', stanza)
+ self.last_ack = ack['h']
def _handle_request_ack(self, req):
"""Handle an ack request by sending an ack."""
@@ -284,30 +290,27 @@ class XEP_0198(BasePlugin):
def _handle_incoming(self, stanza):
"""Increment the handled counter for each inbound stanza."""
- if not self.enabled.is_set():
+ if not self.enabled:
return stanza
if isinstance(stanza, (Message, Presence, Iq)):
- with self.handled_lock:
- # Sequence numbers are mod 2^32
- self.handled = (self.handled + 1) % MAX_SEQ
+ # Sequence numbers are mod 2^32
+ self.handled = (self.handled + 1) % MAX_SEQ
return stanza
def _handle_outgoing(self, stanza):
"""Store outgoing stanzas in a queue to be acked."""
- if not self.enabled.is_set():
+ if not self.enabled:
return stanza
if isinstance(stanza, (Message, Presence, Iq)):
seq = None
- with self.seq_lock:
- # Sequence numbers are mod 2^32
- self.seq = (self.seq + 1) % MAX_SEQ
- seq = self.seq
+ # Sequence numbers are mod 2^32
+ self.seq = (self.seq + 1) % MAX_SEQ
+ seq = self.seq
self.unacked_queue.append((seq, stanza))
- with self.window_counter_lock:
- self.window_counter -= 1
- if self.window_counter == 0:
- self.window_counter = self.window
- self.request_ack()
+ self.window_counter -= 1
+ if self.window_counter == 0:
+ self.window_counter = self.window
+ self.request_ack()
return stanza