summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sleekxmpp/xmlstream/statemachine.py92
-rw-r--r--tests/test_statemachine.py67
2 files changed, 123 insertions, 36 deletions
diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py
index 67b514a2..590abedc 100644
--- a/sleekxmpp/xmlstream/statemachine.py
+++ b/sleekxmpp/xmlstream/statemachine.py
@@ -5,7 +5,6 @@
See the file license.txt for copying permission.
"""
-from __future__ import with_statement
import threading
import time
import logging
@@ -14,18 +13,21 @@ import logging
class StateMachine(object):
def __init__(self, states=[]):
- self.lock = threading.Condition(threading.RLock())
+ self.lock = threading.Lock()
+ self.notifier = threading.Event()
self.__states= []
self.addStates(states)
self.__default_state = self.__states[0]
self.__current_state = self.__default_state
def addStates(self, states):
- with self.lock:
+ self.lock.acquire()
+ try:
for state in states:
if state in self.__states:
raise IndexError("The state '%s' is already in the StateMachine." % state)
self.__states.append( state )
+ finally: self.lock.release()
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
@@ -78,30 +80,33 @@ class StateMachine(object):
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
- with self.lock:
- start = time.time()
- while not self.__current_state in from_states:
- # detect timeout:
- if time.time() >= start + wait: return False
- self.lock.wait(wait)
-
+
+ start = time.time()
+ while not self.__current_state in from_states or not self.lock.acquire(False):
+ # detect timeout:
+ if time.time() >= start + wait: return False
+ self.notifier.wait(wait)
+
+ try: # lock is acquired; all other threads will return false or wait until notify/timeout
+ self.notifier.clear()
if self.__current_state in from_states: # should always be True due to lock
-
- return_val = True
+
# Note that func might throw an exception, but that's OK, it aborts the transition
- if func is not None: return_val = func(*args,**kwargs)
+ return_val = func(*args,**kwargs) if func is not None else True
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
- self.__current_state = to_state
- self.lock.notify_all()
+ self._set_state( to_state )
return return_val # some 'true' value returned by func or True if func was None
else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False
+ finally:
+ self.notifier.set()
+ self.lock.release()
def transition_ctx(self, from_state, to_state, wait=0.0):
@@ -148,7 +153,15 @@ class StateMachine(object):
def ensure_any(self, states, wait=0.0):
'''
- Ensure we are currently in one of the given `states`
+ Ensure we are currently in one of the given `states` or wait until
+ we enter one of those states.
+
+ Note that due to the nature of the function, you cannot guarantee that
+ the entirety of some operation completes while you remain in a given
+ state. That would require acquiring and holding a lock, which
+ would mean no other threads could do the same. (You'd essentially
+ be serializing all of the threads that are 'ensuring' their tasks
+ occurred in some state.
'''
if not (isinstance(states,tuple) or isinstance(states,list)):
raise ValueError('states arg should be a tuple or list')
@@ -157,13 +170,17 @@ class StateMachine(object):
if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state )
- with self.lock:
- start = time.time()
- while not self.__current_state in states:
- # detect timeout:
- if time.time() >= start + wait: return False
- self.lock.wait(wait)
- return self.__current_state in states # should always be True due to lock
+ # Locking never really gained us anything here, since the lock was released
+ # before the function returned anyways. The only thing it _did_ do was
+ # increase the probability that this function would block for longer than
+ # intended if a `transition` function or context was running while holding
+ # the lock.
+ start = time.time()
+ while not self.__current_state in states:
+ # detect timeout:
+ if time.time() >= start + wait: return False
+ self.notifier.wait(wait)
+ return True
def reset(self):
@@ -202,19 +219,19 @@ class _StateCtx:
self.from_state = from_state
self.to_state = to_state
self.wait = wait
- self._timeout = False
+ self._locked = False
def __enter__(self):
- self.state_machine.lock.acquire()
start = time.time()
- while not self.state_machine[ self.from_state ]:
+ while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False):
# detect timeout:
if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
- self._timeout = True # to indicate we should not transition
return False
- self.state_machine.lock.wait(self.wait)
+ self.state_machine.notifier.wait(self.wait)
+ self._locked = True # lock has been acquired at this point
+ self.state_machine.notifier.clear()
logging.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() )
return True
@@ -222,13 +239,16 @@ class _StateCtx:
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is not None:
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
- self.state_machine.current_state(), exc_type.__name__, exc_val )
- elif not self._timeout:
- logging.debug(' ==== TRANSITION %s -> %s',
- self.state_machine.current_state(), self.to_state)
- self.state_machine._set_state( self.to_state )
-
- self.state_machine.lock.notify_all()
- self.state_machine.lock.release()
+ self.state_machine.current_state(), exc_type.__name__, exc_val )
+
+ if self._locked:
+ if exc_val is None:
+ logging.debug(' ==== TRANSITION %s -> %s',
+ self.state_machine.current_state(), self.to_state)
+ self.state_machine._set_state( self.to_state )
+
+ self.state_machine.notifier.set()
+ self.state_machine.lock.release()
+
return False # re-raise any exception
diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py
index e44b8e48..0046dd02 100644
--- a/tests/test_statemachine.py
+++ b/tests/test_statemachine.py
@@ -256,6 +256,73 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( s['three'] )
+ def testTransitionsDontUnintentionallyBlock(self):
+ '''
+ There was a bug where a long-running transition (e.g. one with a 'func'
+ arg or a `transition_ctx` call would cause any `transition` or `ensure`
+ call to block since the lock is acquired before checking the current
+ state. Attempts to acquire the mutex need to be non-blocking so when a
+ timeout is _not_ given, the caller can return immediately. At the same
+ time, threads that _do_ want to wait need the ability to be notified
+ (to avoid waiting beyond when the lock is released) so we've moved to a
+ combination of a plain-ol `threading.Lock` to act as mutex, and a
+ `threading.Event` to perform notification for threads who choose to wait.
+ '''
+
+ s = sm.StateMachine(('one','two','three'))
+
+ with s.transition_ctx('two','three') as result:
+ self.failIf( result )
+ self.assertTrue( s['one'] )
+ self.failIf( s.current_state in ('two','three') )
+
+ self.assertTrue( s['one'] )
+
+ statuses = {'t1':"not started",
+ 't2':'not started'}
+
+ def t1():
+ print 'thread 1 started'
+ # no wait, so this should 'return False' immediately.
+ self.failIf( s.transition('two','three') )
+ statuses['t1'] = 'complete'
+ print 'thread 1 transitioned'
+
+ def t2():
+ print 'thread 2 started'
+ self.failIf( s['two'] )
+ self.failIf( s['three'] )
+ # we want this thread to acquire the lock, but for
+ # the second thread not to wait on the first.
+ with s.transition_ctx('one','two', 10) as locked:
+ statuses['t2'] = 'started'
+ print 'thread 2 has entered context'
+ self.assertTrue( locked )
+ # give thread1 a chance to complete while this
+ # thread still owns the lock
+ time.sleep(5)
+ self.assertTrue( s['two'] )
+ statuses['t2'] = 'complete'
+
+ t1 = threading.Thread(target=t1)
+ t2 = threading.Thread(target=t2)
+
+ t2.start() # this should acquire the lock
+ time.sleep(.2)
+ self.assertEqual( 'started', statuses['t2'] )
+ t1.start() # but it shouldn't prevent thread 1 from completing
+ time.sleep(1)
+
+ self.assertEqual( 'complete', statuses['t1'] )
+
+ t1.join()
+ t2.join()
+
+ self.assertEqual( 'complete', statuses['t2'] )
+
+ self.assertTrue( s['two'] )
+
+
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
if __name__ == '__main__': unittest.main()