From 7930ed22f2371ba3405f9644f427bec9554d2a15 Mon Sep 17 00:00:00 2001 From: Thom Nichols Date: Wed, 2 Jun 2010 12:39:54 -0400 Subject: overhauled state machine. Now allows for atomic transitions. Next step: atomic function calls (and maybe 'handlers') on state transition. --- sleekxmpp/xmlstream/statemachine.py | 139 +++++++++++++++++++++++++++--------- sleekxmpp/xmlstream/xmlstream.py | 90 +++++++++-------------- tests/test_statemachine.py | 116 ++++++++++++++++++++++++++++++ 3 files changed, 254 insertions(+), 91 deletions(-) create mode 100644 tests/test_statemachine.py diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py index fb7d1508..c5f51765 100644 --- a/sleekxmpp/xmlstream/statemachine.py +++ b/sleekxmpp/xmlstream/statemachine.py @@ -7,53 +7,124 @@ """ from __future__ import with_statement import threading +import time +import logging class StateMachine(object): - def __init__(self, states=[], groups=[]): - self.lock = threading.Lock() - self.__state = {} - self.__default_state = {} - self.__group = {} + def __init__(self, states=[]): + self.lock = threading.Condition(threading.RLock()) + self.__states= [] self.addStates(states) - self.addGroups(groups) + self.__default_state = self.__states[0] + self.__current_state = self.__default_state def addStates(self, states): with self.lock: for state in states: - if state in self.__state or state in self.__group: - raise IndexError("The state or group '%s' is already in the StateMachine." % state) - self.__state[state] = states[state] - self.__default_state[state] = states[state] + if state in self.__states: + raise IndexError("The state '%s' is already in the StateMachine." % state) + self.__states.append( state ) - def addGroups(self, groups): - with self.lock: - for gstate in groups: - if gstate in self.__state or gstate in self.__group: - raise IndexError("The key or group '%s' is already in the StateMachine." % gstate) - for state in groups[gstate]: - if state in self.__state: - raise IndexError("The group %s contains a key %s which is not set in the StateMachine." % (gstate, state)) - self.__group[gstate] = groups[gstate] - def set(self, state, status): + def transition(self, from_state, to_state, wait=0.0): + ''' + Transition from the given `from_state` to the given `to_state`. + This method will return `True` if the state machine is now in `to_state`. It + will return `False` if a timeout occurred the transition did not occur. + If `wait` is 0 (the default,) this method returns immediately if the state machine + is not in `from_state`. + + If you want the thread to block and transition once the state machine to enters + `from_state`, set `wait` to a non-negative value. Note there is no 'block + indefinitely' flag since this leads to deadlock. If you want to wait indefinitely, + choose a reasonable value for `wait` (e.g. 20 seconds) and do so in a while loop like so: + + :: + + while not thread_should_exit and not state_machine.transition('disconnected', 'connecting', wait=20 ): + pass # timeout will occur every 20s unless transition occurs + if thread_should_exit: return + # perform actions here after successful transition + + This allows the thread to be interrupted by setting `thread_should_exit=True` + ''' + + return self.transition_any( (from_state,), to_state, wait=wait ) + + def transition_any(self, from_states, to_state, wait=0.0): + ''' + Transition from any of the given `from_states` to the given `to_state`. + ''' + with self.lock: - if state in self.__state: - self.__state[state] = bool(status) + for state in from_states: + if isinstance(state,tuple) or isinstance(state,list): + raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) ) + if not state in self.__states: + raise ValueError( "StateMachine does not contain from_state %s." % state ) + if not to_state in self.__states: + raise ValueError( "StateMachine does not contain to_state %s." % to_state ) + + start = time.time() + while not self.__current_state in from_states: + # detect timeout: + if time.time() >= start + wait: return False + self.lock.wait(wait) + + if self.__current_state in from_states: # should always be True due to lock + logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) + self.__current_state = to_state + self.lock.notifyAll() + return True else: - raise KeyError("StateMachine does not contain state %s." % state) + logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) + return False + + + def ensure(self, state, wait=0.0): + ''' + Ensure the state machine is currently in `state`, or wait until it enters `state`. + ''' + return self.ensure_any( (state,), wait=wait ) + + def ensure_any(self, states, wait=0.0): + ''' + Ensure we are currently in one of the given `states` + ''' + with self.lock: + for state in states: + if isinstance(state,tuple) or isinstance(state,list): + raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) ) + if not state in self.__states: + raise ValueError( "StateMachine does not contain state %s." % state ) + + start = time.time() + while not self.__current_state in states: + # detect timeout: + if time.time() >= start + wait: return False + self.lock.wait(wait) + return self.__current_state in states # should always be True due to lock + - def __getitem__(self, key): - if key in self.__group: - for state in self.__group[key]: - if not self.__state[state]: - return False - return True - return self.__state[key] + def reset(self): + # TODO need to lock before calling this? + self.transition(self.__current_state, self._default_state) - def __getattr__(self, attr): - return self.__getitem__(attr) + + def __getitem__(self, state): + ''' + Non-blocking, non-synchronized test to determine if we are in the given state. + Use `StateMachine.ensure(state)` to wait until the machine enters a certain state. + ''' + return self.__current_state == state - def reset(self): - self.__state = self.__default_state + def __enter__(self): + self.lock.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.lock.nofityAll() + self.lock.release() + return False # re-raise any exception diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index fd307a5c..3bcb3412 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -53,8 +53,9 @@ class XMLStream(object): global ssl_support self.ssl_support = ssl_support self.escape_quotes = escape_quotes - self.state = statemachine.StateMachine() - self.state.addStates({'connected':False, 'is client':False, 'ssl':False, 'tls':False, 'reconnect':True, 'processing':False}) #set initial states + self.state = statemachine.StateMachine(('disconnected','connecting', + 'connected')) + self.should_reconnect = True self.setSocket(socket) self.address = (host, int(port)) @@ -86,21 +87,21 @@ class XMLStream(object): def setSocket(self, socket): "Set the socket" self.socket = socket - if socket is not None: + if socket is not None and self.state.transition('disconnected','connecting'): self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary - self.state.set('connected', True) - + self.state.transition('connecting','connected') def setFileSocket(self, filesocket): self.filesocket = filesocket - def connect(self, host='', port=0, use_ssl=False, use_tls=True): + def connect(self, host='', port=0, use_ssl=None, use_tls=None): "Link to connectTCP" - return self.connectTCP(host, port, use_ssl, use_tls) + if self.state.transition('disconnected', 'connecting'): + return self.connectTCP(host, port, use_ssl, use_tls) def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True): "Connect and create socket" - while reattempt and not self.state['connected']: + while reattempt and not self.state['connected']: # the self.state part is redundant. logging.debug('connecting....') try: if host and port: @@ -122,7 +123,8 @@ class XMLStream(object): try: self.socket.connect(self.address) self.filesocket = self.socket.makefile('rb', 0) - self.state.set('connected', True) + if not self.state.transition('connecting','connected'): + logging.error( "State transition error!!!! Shouldn't have happened" ) logging.debug('connect complete.') return True except socket.error as serr: @@ -182,7 +184,7 @@ class XMLStream(object): "Start processing the socket." logging.debug('Process thread starting...') while self.run: - self.state.set('processing', True) + if not self.state.ensure('connected',wait=2): continue try: self.sendRaw(self.stream_header) while self.run and self.__readXML(): pass @@ -196,38 +198,16 @@ class XMLStream(object): logging.debug("Restarting stream...") continue # DON'T re-initialize the stream -- this exception is sent # specifically when we've initialized TLS and need to re-send the header. - except KeyboardInterrupt: - logging.debug("Keyboard Escape Detected") - self.state.set('processing', False) - self.state.set('reconnect', False) - self.disconnect() - # TODO this is probably not necessary... - self.eventqueue.put(('quit', None, None)) - return - except SystemExit: - # TODO shouldn't this be the same as KeyboardInterrupt???? + except (KeyboardInterrupt, SystemExit): + logging.debug("System interrupt detected") + self.shutdown() self.eventqueue.put(('quit', None, None)) - return except: logging.exception('Unexpected error in RCV thread') - if not self.state.reconnect: - return - else: - logging.debug('reconnecting...') - self.state.set('processing', False) + if self.should_reconnect: self.disconnect(reconnect=True) - # TODO the individual exception handlers above already handle reconnect! - # Why are we attempting to do it again down here??? -# if self.state['reconnect']: -# self.state.set('connected', False) - self.state.set('processing', False) -# self.reconnect() -# else: -# TODO I think this is getting queued, and when the eventRunner comes back online after -# reconnect, it immediately processes a 'quit' event and exits again, meanwhile the -# rest of the client is just starting to connect and process the incoming event stream!!! -# self.eventqueue.put(('quit', None, None)) - logging.debug('Quitting Process thread') + + logging.debug('Quitting Process thread') def __readXML(self): "Parses the incoming stream, adding to xmlin queue as it goes" @@ -246,7 +226,6 @@ class XMLStream(object): edepth += -1 if edepth == 0 and event == b'end': # what is this case exactly? Premature EOF? - #self.disconnect(reconnect=self.state['reconnect']) logging.debug("Ending readXML loop") return False elif edepth == 1: @@ -261,9 +240,8 @@ class XMLStream(object): def _sendThread(self): logging.debug('send thread starting...') while self.run: - if not self.state['connected']: - logging.warning("Not connected yet...") - time.sleep(1) + if not self.state.ensure('connected',wait=2): continue + data = None try: data = self.sendqueue.get(True,10) @@ -272,7 +250,7 @@ class XMLStream(object): except queue.Empty: logging.debug('nothing on send queue') except socket.timeout: - # this is to prevent hanging + # this is to prevent a thread blocked indefinitely logging.debug('timeout sending packet data') except: logging.warning("Failed to send %s" % data) @@ -282,9 +260,7 @@ class XMLStream(object): # the same thing concurrently. Oops! The safer option would be to throw # some sort of event that could be handled by a common thread or the reader # thread to perform reconnect and then re-initialize the handler threads as well. - if self.state.reconnect: - logging.debug('Reconnecting...') - traceback.print_exc() + if self.should_reconnect: self.disconnect(reconnect=True) def sendRaw(self, data): @@ -292,8 +268,7 @@ class XMLStream(object): return True def disconnect(self, reconnect=False): - self.state.set('reconnect', reconnect) - if not self.state['connected']: + if not self.state.transition('connected','disconnected'): logging.warning("Already disconnected.") return logging.debug("Disconnecting...") @@ -301,10 +276,7 @@ class XMLStream(object): time.sleep(5) #send end of stream #wait for end of stream back - self.run = False - self.scheduler.run = False try: - self.state.set('connected',False) # self.socket.shutdown(socket.SHUT_RDWR) self.socket.close() except socket.error as (errno,strerror): @@ -312,13 +284,17 @@ class XMLStream(object): try: self.filesocket.close() except socket.error as (errno,strerror): - logging.exception("Error closing filesocket.") + logging.exception("Error closing filesocket.") + + if reconnect: self.connect() - def reconnect(self): - self.state.set('tls',False) - self.state.set('ssl',False) - time.sleep(1) - self.connect() + def shutdown(self): + ''' + Disconnects and shuts down all event threads. + ''' + self.disconnect() + self.run = False + self.scheduler.run = False def incoming_filter(self, xmlobj): return xmlobj diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py new file mode 100644 index 00000000..6749c8de --- /dev/null +++ b/tests/test_statemachine.py @@ -0,0 +1,116 @@ +import unittest +import time, threading + +if __name__ == '__main__': + import sys, os + sys.path.insert(0, os.getcwd()) + import sleekxmpp.xmlstream.statemachine as sm + + +class testStateMachine(unittest.TestCase): + + def setUp(self): pass + + + def testDefaults(self): + "Test ensure transitions occur correctly in a single thread" + s = sm.StateMachine(('one','two','three')) +# self.assertTrue(s.one) + self.assertTrue(s['one']) +# self.failIf(s.two) + self.failIf(s['two']) + try: + s.booga + self.fail('s.booga is an invalid state and should throw an exception!') + except: pass #expected exception + + + def testTransitions(self): + "Test ensure transitions occur correctly in a single thread" + s = sm.StateMachine(('one','two','three')) +# self.assertTrue(s.one) + + self.assertTrue( s.transition('one', 'two') ) +# self.assertTrue( s.two ) + self.assertTrue( s['two'] ) +# self.failIf( s.one ) + self.failIf( s['one'] ) + + self.assertTrue( s.transition('two', 'three') ) + self.assertTrue( s['three'] ) + self.failIf( s['two'] ) + + self.assertTrue( s.transition('three', 'one') ) + self.assertTrue( s['one'] ) + self.failIf( s['three'] ) + + # should return False immediately w/ no wait: + self.failIf( s.transition('three', 'one') ) + self.assertTrue( s['one'] ) + self.failIf( s['three'] ) + + # test fail condition w/ a short delay: + self.failIf( s.transition('two', 'three') ) + + # Ensure bad states are weeded out: + try: + s.transition('blah', 'three') + s.fail('Exception expected') + except: pass + + try: + s.transition('one', 'blahblah') + s.fail('Exception expected') + except: pass + + + def testTransitionsBlocking(self): + "Test that transitions block from more than one thread" + + s = sm.StateMachine(('one','two','three')) + self.assertTrue(s['one']) + + now = time.time() + self.failIf( s.transition('two', 'one', wait=5.0) ) + self.assertTrue( time.time() > now + 4 ) + self.assertTrue( time.time() < now + 7 ) + + def testThreadedTransitions(self): + "Test that transitions are atomic in > one thread" + + s = sm.StateMachine(('one','two','three')) + self.assertTrue(s['one']) + + thread_state = {'ready': False, 'transitioned': False} + def t1(): + # this will block until the main thread transitions to 'two' + if s['two']: + print 'thread has already transitioned!' + self.fail() + thread_state['ready'] = True + print 'Thread is ready' + self.assertTrue( s.transition('two','three', wait=20) ) + print 'transitioned to three!' + thread_state['transitioned'] = True + + thread = threading.Thread(target=t1) + thread.daemon = True + thread.start() + start = time.time() + while not thread_state['ready']: + print 'not ready' + if time.time() > start+10: self.fail('Timeout waiting for thread to init!') + time.sleep(0.1) + time.sleep(0.2) # the thread should be blocking on the 'transition' call at this point. + self.failIf( thread_state['transitioned'] ) # ensure it didn't 'go' yet. + print 'transitioning to two!' + self.assertTrue( s.transition('one','two') ) + time.sleep(0.2) # second thread should have transitioned now: + self.assertTrue( thread_state['transitioned'] ) + + + + +suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine) + +if __name__ == '__main__': unittest.main() -- cgit v1.2.3