summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sleekxmpp/xmlstream/statemachine.py139
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py90
-rw-r--r--tests/test_statemachine.py116
3 files changed, 254 insertions, 91 deletions
diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py
index fb7d1508..c5f51765 100644
--- a/sleekxmpp/xmlstream/statemachine.py
+++ b/sleekxmpp/xmlstream/statemachine.py
@@ -7,53 +7,124 @@
"""
from __future__ import with_statement
import threading
+import time
+import logging
class StateMachine(object):
- def __init__(self, states=[], groups=[]):
- self.lock = threading.Lock()
- self.__state = {}
- self.__default_state = {}
- self.__group = {}
+ def __init__(self, states=[]):
+ self.lock = threading.Condition(threading.RLock())
+ self.__states= []
self.addStates(states)
- self.addGroups(groups)
+ self.__default_state = self.__states[0]
+ self.__current_state = self.__default_state
def addStates(self, states):
with self.lock:
for state in states:
- if state in self.__state or state in self.__group:
- raise IndexError("The state or group '%s' is already in the StateMachine." % state)
- self.__state[state] = states[state]
- self.__default_state[state] = states[state]
+ if state in self.__states:
+ raise IndexError("The state '%s' is already in the StateMachine." % state)
+ self.__states.append( state )
- def addGroups(self, groups):
- with self.lock:
- for gstate in groups:
- if gstate in self.__state or gstate in self.__group:
- raise IndexError("The key or group '%s' is already in the StateMachine." % gstate)
- for state in groups[gstate]:
- if state in self.__state:
- raise IndexError("The group %s contains a key %s which is not set in the StateMachine." % (gstate, state))
- self.__group[gstate] = groups[gstate]
- def set(self, state, status):
+ def transition(self, from_state, to_state, wait=0.0):
+ '''
+ Transition from the given `from_state` to the given `to_state`.
+ This method will return `True` if the state machine is now in `to_state`. It
+ will return `False` if a timeout occurred the transition did not occur.
+ If `wait` is 0 (the default,) this method returns immediately if the state machine
+ is not in `from_state`.
+
+ If you want the thread to block and transition once the state machine to enters
+ `from_state`, set `wait` to a non-negative value. Note there is no 'block
+ indefinitely' flag since this leads to deadlock. If you want to wait indefinitely,
+ choose a reasonable value for `wait` (e.g. 20 seconds) and do so in a while loop like so:
+
+ ::
+
+ while not thread_should_exit and not state_machine.transition('disconnected', 'connecting', wait=20 ):
+ pass # timeout will occur every 20s unless transition occurs
+ if thread_should_exit: return
+ # perform actions here after successful transition
+
+ This allows the thread to be interrupted by setting `thread_should_exit=True`
+ '''
+
+ return self.transition_any( (from_state,), to_state, wait=wait )
+
+ def transition_any(self, from_states, to_state, wait=0.0):
+ '''
+ Transition from any of the given `from_states` to the given `to_state`.
+ '''
+
with self.lock:
- if state in self.__state:
- self.__state[state] = bool(status)
+ for state in from_states:
+ if isinstance(state,tuple) or isinstance(state,list):
+ raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
+ if not state in self.__states:
+ raise ValueError( "StateMachine does not contain from_state %s." % state )
+ if not to_state in self.__states:
+ raise ValueError( "StateMachine does not contain to_state %s." % to_state )
+
+ start = time.time()
+ while not self.__current_state in from_states:
+ # detect timeout:
+ if time.time() >= start + wait: return False
+ self.lock.wait(wait)
+
+ if self.__current_state in from_states: # should always be True due to lock
+ logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
+ self.__current_state = to_state
+ self.lock.notifyAll()
+ return True
else:
- raise KeyError("StateMachine does not contain state %s." % state)
+ logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
+ return False
+
+
+ def ensure(self, state, wait=0.0):
+ '''
+ Ensure the state machine is currently in `state`, or wait until it enters `state`.
+ '''
+ return self.ensure_any( (state,), wait=wait )
+
+ def ensure_any(self, states, wait=0.0):
+ '''
+ Ensure we are currently in one of the given `states`
+ '''
+ with self.lock:
+ for state in states:
+ if isinstance(state,tuple) or isinstance(state,list):
+ raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
+ if not state in self.__states:
+ raise ValueError( "StateMachine does not contain state %s." % state )
+
+ start = time.time()
+ while not self.__current_state in states:
+ # detect timeout:
+ if time.time() >= start + wait: return False
+ self.lock.wait(wait)
+ return self.__current_state in states # should always be True due to lock
+
- def __getitem__(self, key):
- if key in self.__group:
- for state in self.__group[key]:
- if not self.__state[state]:
- return False
- return True
- return self.__state[key]
+ def reset(self):
+ # TODO need to lock before calling this?
+ self.transition(self.__current_state, self._default_state)
- def __getattr__(self, attr):
- return self.__getitem__(attr)
+
+ def __getitem__(self, state):
+ '''
+ Non-blocking, non-synchronized test to determine if we are in the given state.
+ Use `StateMachine.ensure(state)` to wait until the machine enters a certain state.
+ '''
+ return self.__current_state == state
- def reset(self):
- self.__state = self.__default_state
+ def __enter__(self):
+ self.lock.acquire()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.lock.nofityAll()
+ self.lock.release()
+ return False # re-raise any exception
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index fd307a5c..3bcb3412 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -53,8 +53,9 @@ class XMLStream(object):
global ssl_support
self.ssl_support = ssl_support
self.escape_quotes = escape_quotes
- self.state = statemachine.StateMachine()
- self.state.addStates({'connected':False, 'is client':False, 'ssl':False, 'tls':False, 'reconnect':True, 'processing':False}) #set initial states
+ self.state = statemachine.StateMachine(('disconnected','connecting',
+ 'connected'))
+ self.should_reconnect = True
self.setSocket(socket)
self.address = (host, int(port))
@@ -86,21 +87,21 @@ class XMLStream(object):
def setSocket(self, socket):
"Set the socket"
self.socket = socket
- if socket is not None:
+ if socket is not None and self.state.transition('disconnected','connecting'):
self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary
- self.state.set('connected', True)
-
+ self.state.transition('connecting','connected')
def setFileSocket(self, filesocket):
self.filesocket = filesocket
- def connect(self, host='', port=0, use_ssl=False, use_tls=True):
+ def connect(self, host='', port=0, use_ssl=None, use_tls=None):
"Link to connectTCP"
- return self.connectTCP(host, port, use_ssl, use_tls)
+ if self.state.transition('disconnected', 'connecting'):
+ return self.connectTCP(host, port, use_ssl, use_tls)
def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True):
"Connect and create socket"
- while reattempt and not self.state['connected']:
+ while reattempt and not self.state['connected']: # the self.state part is redundant.
logging.debug('connecting....')
try:
if host and port:
@@ -122,7 +123,8 @@ class XMLStream(object):
try:
self.socket.connect(self.address)
self.filesocket = self.socket.makefile('rb', 0)
- self.state.set('connected', True)
+ if not self.state.transition('connecting','connected'):
+ logging.error( "State transition error!!!! Shouldn't have happened" )
logging.debug('connect complete.')
return True
except socket.error as serr:
@@ -182,7 +184,7 @@ class XMLStream(object):
"Start processing the socket."
logging.debug('Process thread starting...')
while self.run:
- self.state.set('processing', True)
+ if not self.state.ensure('connected',wait=2): continue
try:
self.sendRaw(self.stream_header)
while self.run and self.__readXML(): pass
@@ -196,38 +198,16 @@ class XMLStream(object):
logging.debug("Restarting stream...")
continue # DON'T re-initialize the stream -- this exception is sent
# specifically when we've initialized TLS and need to re-send the <stream> header.
- except KeyboardInterrupt:
- logging.debug("Keyboard Escape Detected")
- self.state.set('processing', False)
- self.state.set('reconnect', False)
- self.disconnect()
- # TODO this is probably not necessary...
- self.eventqueue.put(('quit', None, None))
- return
- except SystemExit:
- # TODO shouldn't this be the same as KeyboardInterrupt????
+ except (KeyboardInterrupt, SystemExit):
+ logging.debug("System interrupt detected")
+ self.shutdown()
self.eventqueue.put(('quit', None, None))
- return
except:
logging.exception('Unexpected error in RCV thread')
- if not self.state.reconnect:
- return
- else:
- logging.debug('reconnecting...')
- self.state.set('processing', False)
+ if self.should_reconnect:
self.disconnect(reconnect=True)
- # TODO the individual exception handlers above already handle reconnect!
- # Why are we attempting to do it again down here???
-# if self.state['reconnect']:
-# self.state.set('connected', False)
- self.state.set('processing', False)
-# self.reconnect()
-# else:
-# TODO I think this is getting queued, and when the eventRunner comes back online after
-# reconnect, it immediately processes a 'quit' event and exits again, meanwhile the
-# rest of the client is just starting to connect and process the incoming event stream!!!
-# self.eventqueue.put(('quit', None, None))
- logging.debug('Quitting Process thread')
+
+ logging.debug('Quitting Process thread')
def __readXML(self):
"Parses the incoming stream, adding to xmlin queue as it goes"
@@ -246,7 +226,6 @@ class XMLStream(object):
edepth += -1
if edepth == 0 and event == b'end':
# what is this case exactly? Premature EOF?
- #self.disconnect(reconnect=self.state['reconnect'])
logging.debug("Ending readXML loop")
return False
elif edepth == 1:
@@ -261,9 +240,8 @@ class XMLStream(object):
def _sendThread(self):
logging.debug('send thread starting...')
while self.run:
- if not self.state['connected']:
- logging.warning("Not connected yet...")
- time.sleep(1)
+ if not self.state.ensure('connected',wait=2): continue
+
data = None
try:
data = self.sendqueue.get(True,10)
@@ -272,7 +250,7 @@ class XMLStream(object):
except queue.Empty:
logging.debug('nothing on send queue')
except socket.timeout:
- # this is to prevent hanging
+ # this is to prevent a thread blocked indefinitely
logging.debug('timeout sending packet data')
except:
logging.warning("Failed to send %s" % data)
@@ -282,9 +260,7 @@ class XMLStream(object):
# the same thing concurrently. Oops! The safer option would be to throw
# some sort of event that could be handled by a common thread or the reader
# thread to perform reconnect and then re-initialize the handler threads as well.
- if self.state.reconnect:
- logging.debug('Reconnecting...')
- traceback.print_exc()
+ if self.should_reconnect:
self.disconnect(reconnect=True)
def sendRaw(self, data):
@@ -292,8 +268,7 @@ class XMLStream(object):
return True
def disconnect(self, reconnect=False):
- self.state.set('reconnect', reconnect)
- if not self.state['connected']:
+ if not self.state.transition('connected','disconnected'):
logging.warning("Already disconnected.")
return
logging.debug("Disconnecting...")
@@ -301,10 +276,7 @@ class XMLStream(object):
time.sleep(5)
#send end of stream
#wait for end of stream back
- self.run = False
- self.scheduler.run = False
try:
- self.state.set('connected',False)
# self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except socket.error as (errno,strerror):
@@ -312,13 +284,17 @@ class XMLStream(object):
try:
self.filesocket.close()
except socket.error as (errno,strerror):
- logging.exception("Error closing filesocket.")
+ logging.exception("Error closing filesocket.")
+
+ if reconnect: self.connect()
- def reconnect(self):
- self.state.set('tls',False)
- self.state.set('ssl',False)
- time.sleep(1)
- self.connect()
+ def shutdown(self):
+ '''
+ Disconnects and shuts down all event threads.
+ '''
+ self.disconnect()
+ self.run = False
+ self.scheduler.run = False
def incoming_filter(self, xmlobj):
return xmlobj
diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py
new file mode 100644
index 00000000..6749c8de
--- /dev/null
+++ b/tests/test_statemachine.py
@@ -0,0 +1,116 @@
+import unittest
+import time, threading
+
+if __name__ == '__main__':
+ import sys, os
+ sys.path.insert(0, os.getcwd())
+ import sleekxmpp.xmlstream.statemachine as sm
+
+
+class testStateMachine(unittest.TestCase):
+
+ def setUp(self): pass
+
+
+ def testDefaults(self):
+ "Test ensure transitions occur correctly in a single thread"
+ s = sm.StateMachine(('one','two','three'))
+# self.assertTrue(s.one)
+ self.assertTrue(s['one'])
+# self.failIf(s.two)
+ self.failIf(s['two'])
+ try:
+ s.booga
+ self.fail('s.booga is an invalid state and should throw an exception!')
+ except: pass #expected exception
+
+
+ def testTransitions(self):
+ "Test ensure transitions occur correctly in a single thread"
+ s = sm.StateMachine(('one','two','three'))
+# self.assertTrue(s.one)
+
+ self.assertTrue( s.transition('one', 'two') )
+# self.assertTrue( s.two )
+ self.assertTrue( s['two'] )
+# self.failIf( s.one )
+ self.failIf( s['one'] )
+
+ self.assertTrue( s.transition('two', 'three') )
+ self.assertTrue( s['three'] )
+ self.failIf( s['two'] )
+
+ self.assertTrue( s.transition('three', 'one') )
+ self.assertTrue( s['one'] )
+ self.failIf( s['three'] )
+
+ # should return False immediately w/ no wait:
+ self.failIf( s.transition('three', 'one') )
+ self.assertTrue( s['one'] )
+ self.failIf( s['three'] )
+
+ # test fail condition w/ a short delay:
+ self.failIf( s.transition('two', 'three') )
+
+ # Ensure bad states are weeded out:
+ try:
+ s.transition('blah', 'three')
+ s.fail('Exception expected')
+ except: pass
+
+ try:
+ s.transition('one', 'blahblah')
+ s.fail('Exception expected')
+ except: pass
+
+
+ def testTransitionsBlocking(self):
+ "Test that transitions block from more than one thread"
+
+ s = sm.StateMachine(('one','two','three'))
+ self.assertTrue(s['one'])
+
+ now = time.time()
+ self.failIf( s.transition('two', 'one', wait=5.0) )
+ self.assertTrue( time.time() > now + 4 )
+ self.assertTrue( time.time() < now + 7 )
+
+ def testThreadedTransitions(self):
+ "Test that transitions are atomic in > one thread"
+
+ s = sm.StateMachine(('one','two','three'))
+ self.assertTrue(s['one'])
+
+ thread_state = {'ready': False, 'transitioned': False}
+ def t1():
+ # this will block until the main thread transitions to 'two'
+ if s['two']:
+ print 'thread has already transitioned!'
+ self.fail()
+ thread_state['ready'] = True
+ print 'Thread is ready'
+ self.assertTrue( s.transition('two','three', wait=20) )
+ print 'transitioned to three!'
+ thread_state['transitioned'] = True
+
+ thread = threading.Thread(target=t1)
+ thread.daemon = True
+ thread.start()
+ start = time.time()
+ while not thread_state['ready']:
+ print 'not ready'
+ if time.time() > start+10: self.fail('Timeout waiting for thread to init!')
+ time.sleep(0.1)
+ time.sleep(0.2) # the thread should be blocking on the 'transition' call at this point.
+ self.failIf( thread_state['transitioned'] ) # ensure it didn't 'go' yet.
+ print 'transitioning to two!'
+ self.assertTrue( s.transition('one','two') )
+ time.sleep(0.2) # second thread should have transitioned now:
+ self.assertTrue( thread_state['transitioned'] )
+
+
+
+
+suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
+
+if __name__ == '__main__': unittest.main()