summaryrefslogtreecommitdiff
path: root/tests/test_statemachine.py
blob: 0046dd022eaafaa0cb99269ecb2b39aa3289a59e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import unittest
import time, threading, random, functools

if __name__ == '__main__':
	# When run directly as a script, make the checkout in the current
	# working directory importable ahead of any installed copy.
	import sys, os
	sys.path.insert(0, os.getcwd())

# Imported unconditionally: the module-level `suite` defined at the bottom
# of this file may be run by an external runner that imports this module,
# and the tests need `sm` in that case too (previously it was only bound
# when the file ran as __main__, giving a NameError otherwise).
import sleekxmpp.xmlstream.statemachine as sm


class testStateMachine(unittest.TestCase):
	"""Tests for sleekxmpp.xmlstream.statemachine.StateMachine.

	Covers single-threaded transitions, `wait` timeouts, `func` guards, the
	`transition_ctx` context manager, and multi-threaded contention.  Several
	tests rely on wall-clock sleeps, so the case takes tens of seconds.

	Assertion failures raised inside helper threads do not fail a test by
	themselves, so each threaded test records progress in shared state (a
	dict, or the machine's own state) and asserts on it from the main thread.
	"""

	def testDefaults(self):
		"Ensure a new machine starts in its first listed state and rejects unknown state names."
		s = sm.StateMachine(('one','two','three'))
		self.assertTrue(s['one'])
		self.assertFalse(s['two'])

		# assertRaises rather than try/self.fail()/bare-except: a bare except
		# also swallows the AssertionError from self.fail(), so that form
		# could never fail.
		self.assertRaises(Exception, lambda: s['booga'])

		# just make sure __str__ works, no reason to test its exact value:
		print(str(s))


	def testTransitions(self):
		"Ensure transitions occur correctly in a single thread."
		s = sm.StateMachine(('one','two','three'))

		self.assertTrue( s.transition('one', 'two') )
		self.assertTrue( s['two'] )
		self.assertFalse( s['one'] )

		self.assertTrue( s.transition('two', 'three') )
		self.assertTrue( s['three'] )
		self.assertFalse( s['two'] )

		self.assertTrue( s.transition('three', 'one') )
		self.assertTrue( s['one'] )
		self.assertFalse( s['three'] )

		# should return False immediately w/ no wait:
		self.assertFalse( s.transition('three', 'one') )
		self.assertTrue( s['one'] )
		self.assertFalse( s['three'] )

		# a transition out of a state we are not in fails (no wait given):
		self.assertFalse( s.transition('two', 'three') )

		# Ensure bad states are weeded out.  assertRaises only inspects the
		# call itself, so a missing exception is reported as a failure instead
		# of being hidden by a bare except.
		self.assertRaises(Exception, s.transition, 'blah', 'three')
		self.assertRaises(Exception, s.transition, 'one', 'blahblah')


	def testTransitionsBlocking(self):
		"Test that a transition given `wait` blocks for about that long before failing."

		s = sm.StateMachine(('one','two','three'))
		self.assertTrue(s['one'])

		start = time.time()
		# nothing ever enters 'two', so this call can only time out:
		self.assertFalse( s.transition('two', 'one', wait=5.0) )
		elapsed = time.time() - start
		self.assertTrue( 4 < elapsed < 7, 'transition waited %.2fs' % elapsed )

	def testThreadedTransitions(self):
		"Test that a waiting transition in one thread proceeds once another thread enables it."

		s = sm.StateMachine(('one','two','three'))
		self.assertTrue(s['one'])

		# Progress flags written by the worker and checked from this thread.
		thread_state = {'ready': False, 'transitioned': False}
		def t1():
			if s['two']:
				print('thread has already transitioned!')
				self.fail()
			thread_state['ready'] = True
			print('Thread is ready')
			# this will block until the main thread transitions to 'two'
			self.assertTrue( s.transition('two','three', wait=20) )
			print('transitioned to three!')
			thread_state['transitioned'] = True

		thread = threading.Thread(target=t1)
		thread.daemon = True
		thread.start()
		start = time.time()
		while not thread_state['ready']:
			print('not ready')
			if time.time() > start+10: self.fail('Timeout waiting for thread to init!')
			time.sleep(0.1)
		time.sleep(0.2) # the thread should be blocking on the 'transition' call at this point.
		self.assertFalse( thread_state['transitioned'] ) # ensure it didn't 'go' yet.
		print('transitioning to two!')
		self.assertTrue( s.transition('one','two') )
		time.sleep(0.2) # second thread should have transitioned now:
		self.assertTrue( thread_state['transitioned'] )


	def testForRaceCondition(self):
		"""Let many threads attempt the same transition at once; only one
		should ever make it."""

		s = sm.StateMachine(('one','two','three'))

		def t1(num):
			# Idle until triggered.  Once this thread has recorded a result
			# (True/False) it waits for the main thread to reset or quit it.
			while True:
				if not trigger['go'] or thread_state[num] in (True,False):
					time.sleep( random.random()/100 ) # < .01s
					if thread_state[num] == 'quit': break
					continue

				thread_state[num] = s.transition('one','two' )

		thread_count = 20
		threads = []
		thread_state = {}
		def reset():
			for c in range(thread_count): thread_state[c] = "reset"
		trigger = {'go':False} # shared 'go' flag, polled by every worker thread

		for c in range(thread_count):
			thread_state[c] = "reset"
			thread = threading.Thread( target= functools.partial(t1,c) )
			threads.append( thread )
			thread.daemon = True
			thread.start()

		for _ in range(100): # this will take 10s to execute
			trigger['go'] = True
			time.sleep(.1)
			trigger['go'] = False
			winners = 0
			for (num, state) in thread_state.items():
				if state == True: winners += 1
				elif state != False: raise Exception( "!%d!%s!" % (num,state) )

			self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
			self.assertTrue( s.ensure('two') )
			self.assertTrue( s.transition('two','one') ) # return to the first state.
			reset()

		# now let the threads quit gracefully; wait for them to exit (at most
		# 2s in total) rather than sleeping a fixed amount:
		for c in range(thread_count): thread_state[c] = 'quit'
		deadline = time.time() + 2
		for thread in threads:
			thread.join(max(0, deadline - time.time()))


	def testTransitionFunctions(self):
		"test that a `func` argument allows or blocks the transition correctly."

		s = sm.StateMachine(('one','two','three'))

		def alwaysFalse(): return False
		def alwaysTrue(): return True

		self.assertFalse( s.transition('one','two', func=alwaysFalse) )
		self.assertTrue(s['one'])
		self.assertFalse(s['two'])

		self.assertTrue( s.transition('one','two', func=alwaysTrue) )
		self.assertFalse(s['one'])
		self.assertTrue(s['two'])


	def testTransitionFuncException(self):
		"if a transition function throws an exception, ensure we're in a sane state"

		s = sm.StateMachine(('one','two','three'))

		def alwaysException(): raise Exception('whups!')

		# The exception from `func` is expected to reach the caller.
		# assertRaises is used because a bare `except: pass` would also
		# swallow the AssertionError from self.fail().
		self.assertRaises(Exception, s.transition, 'one', 'two', func=alwaysException)

		self.assertTrue(s['one'])
		self.assertFalse(s['two'])

		# ensure a subsequent attempt completes normally:
		self.assertTrue( s.transition('one','two') )
		self.assertFalse(s['one'])
		self.assertTrue(s['two'])


	def testContextManager(self):
		"Test that transition_ctx transitions on a clean exit and leaves the state unchanged when the body raises."

		s = sm.StateMachine(('one','two','three'))

		with s.transition_ctx('one','two'):
			self.assertTrue( s['one'] )
			self.assertFalse( s['two'] )

		#successful transition b/c no exception was thrown
		self.assertTrue( s['two'] )
		self.assertFalse( s['one'] )

		# failed transition because exception is thrown; the exception is
		# expected to propagate out of the context manager:
		def failing_transition():
			with s.transition_ctx('two','three'):
				raise Exception("boom!")
		self.assertRaises(Exception, failing_transition)

		self.assertFalse( s.current_state() in ('one','three') )
		self.assertTrue( s['two'] )

	def testCtxManagerTransitionFailure(self):
		"""Test transition_ctx when the from-state doesn't match: immediately
		(no wait), and while waiting for another thread to enable it."""

		s = sm.StateMachine(('one','two','three'))

		# not in 'two' and no wait given, so the context yields a false result
		with s.transition_ctx('two','three') as result:
			self.assertFalse( result )
			self.assertTrue( s['one'] )
			# current_state must be *called*: the bound method itself is never
			# a member of the tuple, which made this check vacuous.
			self.assertFalse( s.current_state() in ('two','three') )

		self.assertTrue( s['one'] )

		def r1():
			print('thread 1 started')
			self.assertTrue( s.transition('one','two') )
			print('thread 1 transitioned')

		def r2():
			print('thread 2 started')
			self.assertFalse( s['two'] )
			with s.transition_ctx('two','three', 10) as result:
				self.assertTrue( result )
				self.assertTrue( s['two'] )
				print('thread 2 will transition on exit from the context manager...')
			self.assertTrue( s['three'] )
			print('transitioned to %s' % s.current_state())

		thread1 = threading.Thread(target=r1)
		thread2 = threading.Thread(target=r2)

		thread2.start() # this should block until r1 goes
		time.sleep(1)
		thread1.start()

		thread1.join()
		thread2.join()

		# failures inside r1/r2 don't fail the test on their own; this final
		# check catches a transition that never happened.
		self.assertTrue( s['three'] )


	def testTransitionsDontUnintentionallyBlock(self):
		'''
		There was a bug where a long-running transition (e.g. one with a 'func'
		arg, or a `transition_ctx` call) would cause any `transition` or `ensure`
		call to block since the lock is acquired before checking the current
		state.  Attempts to acquire the mutex need to be non-blocking so when a
		timeout is _not_ given, the caller can return immediately.  At the same
		time, threads that _do_ want to wait need the ability to be notified
		(to avoid waiting beyond when the lock is released) so we've moved to a
		combination of a plain-ol `threading.Lock` to act as mutex, and a
		`threading.Event` to perform notification for threads who choose to wait.
		'''

		s = sm.StateMachine(('one','two','three'))

		with s.transition_ctx('two','three') as result:
			self.assertFalse( result )
			self.assertTrue( s['one'] )
			# call current_state(); the bare method is never in the tuple
			self.assertFalse( s.current_state() in ('two','three') )

		self.assertTrue( s['one'] )

		# Progress markers checked from the main thread, since assertion
		# failures inside the workers would not fail the test by themselves.
		statuses = {'t1':"not started",
					't2':'not started'}

		def t1():
			print('thread 1 started')
			# no wait, so this should 'return False' immediately.
			self.assertFalse( s.transition('two','three') )
			statuses['t1'] = 'complete'
			print('thread 1 transitioned')

		def t2():
			print('thread 2 started')
			self.assertFalse( s['two'] )
			self.assertFalse( s['three'] )
			# we want this thread to acquire the lock, but for
			# the second thread not to wait on the first.
			with s.transition_ctx('one','two', 10) as locked:
				statuses['t2'] = 'started'
				print('thread 2 has entered context')
				self.assertTrue( locked )
				# give thread1 a chance to complete while this
				# thread still owns the lock
				time.sleep(5)
			self.assertTrue( s['two'] )
			statuses['t2'] = 'complete'

		# distinct names so the Thread objects don't shadow their targets
		thread1 = threading.Thread(target=t1)
		thread2 = threading.Thread(target=t2)

		thread2.start() # this should acquire the lock
		time.sleep(.2)
		self.assertEqual( 'started', statuses['t2'] )
		thread1.start() # but it shouldn't prevent thread 1 from completing
		time.sleep(1)

		self.assertEqual( 'complete', statuses['t1'] )

		thread1.join()
		thread2.join()

		self.assertEqual( 'complete', statuses['t2'] )

		self.assertTrue( s['two'] )


# Module-level suite, presumably collected by an aggregate test runner that
# imports this file (TODO confirm against that runner).
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)

if __name__ == '__main__':
	unittest.main()