diff options
-rw-r--r-- | sleekxmpp/xmlstream/handler/callback.py | 4 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/statemachine.py | 104 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 18 | ||||
-rw-r--r-- | tests/test_statemachine.py | 67 |
4 files changed, 143 insertions, 50 deletions
diff --git a/sleekxmpp/xmlstream/handler/callback.py b/sleekxmpp/xmlstream/handler/callback.py index ea5acb5b..7b8c98d2 100644 --- a/sleekxmpp/xmlstream/handler/callback.py +++ b/sleekxmpp/xmlstream/handler/callback.py @@ -20,12 +20,12 @@ class Callback(base.BaseHandler): def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN! base.BaseHandler.prerun(self, payload) if self._instream: - logging.debug('callback "%s" prerun', self.name) +# logging.debug('callback "%s" prerun', self.name) self.run(payload, True) def run(self, payload, instream=False): if not self._instream or instream: - logging.debug('callback "%s" run', self.name) +# logging.debug('callback "%s" run', self.name) base.BaseHandler.run(self, payload) #if self._thread: # x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,)) diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py index 67b514a2..9412d5ad 100644 --- a/sleekxmpp/xmlstream/statemachine.py +++ b/sleekxmpp/xmlstream/statemachine.py @@ -5,27 +5,31 @@ See the file license.txt for copying permission. """ -from __future__ import with_statement import threading import time import logging +log = logging.getLogger(__name__) + class StateMachine(object): def __init__(self, states=[]): - self.lock = threading.Condition(threading.RLock()) + self.lock = threading.Lock() + self.notifier = threading.Event() self.__states= [] self.addStates(states) self.__default_state = self.__states[0] self.__current_state = self.__default_state def addStates(self, states): - with self.lock: + self.lock.acquire() + try: for state in states: if state in self.__states: raise IndexError("The state '%s' is already in the StateMachine." % state) self.__states.append( state ) + finally: self.lock.release() def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ): @@ -78,30 +82,33 @@ class StateMachine(object): if not to_state in self.__states: raise ValueError( "StateMachine does not contain to_state %s." % to_state ) - with self.lock: - start = time.time() - while not self.__current_state in from_states: - # detect timeout: - if time.time() >= start + wait: return False - self.lock.wait(wait) - + + start = time.time() + while not self.__current_state in from_states or not self.lock.acquire(False): + # detect timeout: + if time.time() >= start + wait: return False + self.notifier.wait(wait) + + try: # lock is acquired; all other threads will return false or wait until notify/timeout + self.notifier.clear() if self.__current_state in from_states: # should always be True due to lock - - return_val = True + # Note that func might throw an exception, but that's OK, it aborts the transition - if func is not None: return_val = func(*args,**kwargs) + return_val = func(*args,**kwargs) if func is not None else True # some 'false' value returned from func, # indicating that transition should not occur: if not return_val: return return_val - logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) - self.__current_state = to_state - self.lock.notify_all() + log.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) + self._set_state( to_state ) return return_val # some 'true' value returned by func or True if func was None else: - logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) + log.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) return False + finally: + self.notifier.set() + self.lock.release() def transition_ctx(self, from_state, to_state, wait=0.0): @@ -148,7 +155,15 @@ class StateMachine(object): def ensure_any(self, states, wait=0.0): ''' - Ensure we are currently in one of the given `states` + Ensure we are currently in one of the given `states` or wait until + we enter one of those states. + + Note that due to the nature of the function, you cannot guarantee that + the entirety of some operation completes while you remain in a given + state. That would require acquiring and holding a lock, which + would mean no other threads could do the same. (You'd essentially + be serializing all of the threads that are 'ensuring' their tasks + occurred in some state. ''' if not (isinstance(states,tuple) or isinstance(states,list)): raise ValueError('states arg should be a tuple or list') @@ -157,13 +172,17 @@ class StateMachine(object): if not state in self.__states: raise ValueError( "StateMachine does not contain state '%s'" % state ) - with self.lock: - start = time.time() - while not self.__current_state in states: - # detect timeout: - if time.time() >= start + wait: return False - self.lock.wait(wait) - return self.__current_state in states # should always be True due to lock + # Locking never really gained us anything here, since the lock was released + # before the function returned anyways. The only thing it _did_ do was + # increase the probability that this function would block for longer than + # intended if a `transition` function or context was running while holding + # the lock. + start = time.time() + while not self.__current_state in states: + # detect timeout: + if time.time() >= start + wait: return False + self.notifier.wait(wait) + return True def reset(self): @@ -202,33 +221,36 @@ class _StateCtx: self.from_state = from_state self.to_state = to_state self.wait = wait - self._timeout = False + self._locked = False def __enter__(self): - self.state_machine.lock.acquire() start = time.time() - while not self.state_machine[ self.from_state ]: + while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False): # detect timeout: if time.time() >= start + self.wait: - logging.debug('StateMachine timeout while waiting for state: %s', self.from_state ) - self._timeout = True # to indicate we should not transition + log.debug('StateMachine timeout while waiting for state: %s', self.from_state ) return False - self.state_machine.lock.wait(self.wait) + self.state_machine.notifier.wait(self.wait) - logging.debug('StateMachine entered context in state: %s', + self._locked = True # lock has been acquired at this point + self.state_machine.notifier.clear() + log.debug('StateMachine entered context in state: %s', self.state_machine.current_state() ) return True def __exit__(self, exc_type, exc_val, exc_tb): if exc_val is not None: - logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s", - self.state_machine.current_state(), exc_type.__name__, exc_val ) - elif not self._timeout: - logging.debug(' ==== TRANSITION %s -> %s', - self.state_machine.current_state(), self.to_state) - self.state_machine._set_state( self.to_state ) - - self.state_machine.lock.notify_all() - self.state_machine.lock.release() + log.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s", + self.state_machine.current_state(), exc_type.__name__, exc_val ) + + if self._locked: + if exc_val is None: + log.debug(' ==== TRANSITION %s -> %s', + self.state_machine.current_state(), self.to_state) + self.state_machine._set_state( self.to_state ) + + self.state_machine.notifier.set() + self.state_machine.lock.release() + return False # re-raise any exception diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 025a6cbf..842dfee2 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -58,8 +58,7 @@ class XMLStream(object): global ssl_support self.ssl_support = ssl_support self.escape_quotes = escape_quotes - self.state = statemachine.StateMachine(('disconnected','connecting', - 'connected')) + self.state = statemachine.StateMachine(('disconnected','connected')) self.should_reconnect = True self.setSocket(socket) @@ -92,9 +91,11 @@ class XMLStream(object): def setSocket(self, socket): "Set the socket" self.socket = socket - if socket is not None and self.state.transition('disconnected','connecting'): - self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary - self.state.transition('connecting','connected') + if socket is not None: + with self.state.transition_ctx('disconnected','connected') as locked: + if not locked: raise Exception('Already connected') + # ElementTree.iterparse requires a file. 0 buffer files have to be binary + self.filesocket = socket.makefile('rb', 0) def setFileSocket(self, filesocket): self.filesocket = filesocket @@ -235,6 +236,9 @@ class XMLStream(object): logging.debug("System interrupt detected") self.shutdown() self.eventqueue.put(('quit', None, None)) + except cElementTree.XMLParserError: + logging.warn('XML RCV parsing error!', exc_info=1) + # don't restart the stream on an XML parse error. except: logging.exception('Unexpected error in RCV thread') if self.should_reconnect: @@ -356,11 +360,11 @@ class XMLStream(object): # TODO inefficient linear search; performance might be improved by hashtable lookup for handler in self.__handlers: if handler.match(stanza): - logging.debug('matched stanza to handler %s', handler.name) +# logging.debug('matched stanza to handler %s', handler.name) handler.prerun(stanza) self.eventqueue.put(('stanza', handler, stanza)) if handler.checkDelete(): - logging.debug('deleting callback %s', handler.name) +# logging.debug('deleting callback %s', handler.name) self.__handlers.pop(self.__handlers.index(handler)) unhandled = False if unhandled: diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py index e44b8e48..0046dd02 100644 --- a/tests/test_statemachine.py +++ b/tests/test_statemachine.py @@ -256,6 +256,73 @@ class testStateMachine(unittest.TestCase): self.assertTrue( s['three'] ) + def testTransitionsDontUnintentionallyBlock(self): + ''' + There was a bug where a long-running transition (e.g. one with a 'func' + arg or a `transition_ctx` call would cause any `transition` or `ensure` + call to block since the lock is acquired before checking the current + state. Attempts to acquire the mutex need to be non-blocking so when a + timeout is _not_ given, the caller can return immediately. At the same + time, threads that _do_ want to wait need the ability to be notified + (to avoid waiting beyond when the lock is released) so we've moved to a + combination of a plain-ol `threading.Lock` to act as mutex, and a + `threading.Event` to perform notification for threads who choose to wait. + ''' + + s = sm.StateMachine(('one','two','three')) + + with s.transition_ctx('two','three') as result: + self.failIf( result ) + self.assertTrue( s['one'] ) + self.failIf( s.current_state in ('two','three') ) + + self.assertTrue( s['one'] ) + + statuses = {'t1':"not started", + 't2':'not started'} + + def t1(): + print 'thread 1 started' + # no wait, so this should 'return False' immediately. + self.failIf( s.transition('two','three') ) + statuses['t1'] = 'complete' + print 'thread 1 transitioned' + + def t2(): + print 'thread 2 started' + self.failIf( s['two'] ) + self.failIf( s['three'] ) + # we want this thread to acquire the lock, but for + # the second thread not to wait on the first. + with s.transition_ctx('one','two', 10) as locked: + statuses['t2'] = 'started' + print 'thread 2 has entered context' + self.assertTrue( locked ) + # give thread1 a chance to complete while this + # thread still owns the lock + time.sleep(5) + self.assertTrue( s['two'] ) + statuses['t2'] = 'complete' + + t1 = threading.Thread(target=t1) + t2 = threading.Thread(target=t2) + + t2.start() # this should acquire the lock + time.sleep(.2) + self.assertEqual( 'started', statuses['t2'] ) + t1.start() # but it shouldn't prevent thread 1 from completing + time.sleep(1) + + self.assertEqual( 'complete', statuses['t1'] ) + + t1.join() + t2.join() + + self.assertEqual( 'complete', statuses['t2'] ) + + self.assertTrue( s['two'] ) + + suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine) if __name__ == '__main__': unittest.main() |