summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore4
-rw-r--r--conn_tests/test_pubsubjobs.py171
-rw-r--r--conn_tests/test_pubsubserver.py1
-rw-r--r--sleekxmpp/__init__.py149
-rw-r--r--sleekxmpp/basexmpp.py39
-rw-r--r--sleekxmpp/plugins/jobs.py44
-rw-r--r--sleekxmpp/plugins/stanza_pubsub.py45
-rw-r--r--sleekxmpp/plugins/xep_0004.py1
-rw-r--r--sleekxmpp/plugins/xep_0030.py405
-rw-r--r--sleekxmpp/plugins/xep_0060.py10
-rw-r--r--sleekxmpp/stanza/error.py10
-rw-r--r--sleekxmpp/stanza/iq.py14
-rw-r--r--sleekxmpp/xmlstream/handler/base.py2
-rw-r--r--sleekxmpp/xmlstream/handler/callback.py4
-rw-r--r--sleekxmpp/xmlstream/scheduler.py87
-rw-r--r--sleekxmpp/xmlstream/stanzabase.py22
-rw-r--r--sleekxmpp/xmlstream/statemachine.py245
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py323
-rw-r--r--tests/test_disco.py155
-rw-r--r--tests/test_pubsubstanzas.py15
-rw-r--r--tests/test_statemachine.py261
21 files changed, 1654 insertions, 353 deletions
diff --git a/.gitignore b/.gitignore
index 0fe2c40e..6257bbf6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,6 @@
*.pyc
+.project
build/
+*.swp
+.pydevproject
+.settings
diff --git a/conn_tests/test_pubsubjobs.py b/conn_tests/test_pubsubjobs.py
new file mode 100644
index 00000000..edf22ccc
--- /dev/null
+++ b/conn_tests/test_pubsubjobs.py
@@ -0,0 +1,171 @@
+import logging
+import sleekxmpp
+from optparse import OptionParser
+from xml.etree import cElementTree as ET
+import os
+import time
+import sys
+import unittest
+import sleekxmpp.plugins.xep_0004
+from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath
+from sleekxmpp.xmlstream.handler.waiter import Waiter
+try:
+ import configparser
+except ImportError:
+ import ConfigParser as configparser
+try:
+ import queue
+except ImportError:
+ import Queue as queue
+
+class TestClient(sleekxmpp.ClientXMPP):
+ def __init__(self, jid, password):
+ sleekxmpp.ClientXMPP.__init__(self, jid, password)
+ self.add_event_handler("session_start", self.start)
+ #self.add_event_handler("message", self.message)
+ self.waitforstart = queue.Queue()
+
+ def start(self, event):
+ self.getRoster()
+ self.sendPresence()
+ self.waitforstart.put(True)
+
+
+class TestPubsubServer(unittest.TestCase):
+ statev = {}
+
+ def __init__(self, *args, **kwargs):
+ unittest.TestCase.__init__(self, *args, **kwargs)
+
+ def setUp(self):
+ pass
+
+ def test001getdefaultconfig(self):
+ """Get the default node config"""
+ self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode2')
+ self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode3')
+ self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode4')
+ self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode5')
+ result = self.xmpp1['xep_0060'].getNodeConfig(self.pshost)
+ self.statev['defaultconfig'] = result
+ self.failUnless(isinstance(result, sleekxmpp.plugins.xep_0004.Form))
+
+ def test002createdefaultnode(self):
+ """Create a node without config"""
+ self.failUnless(self.xmpp1['xep_0060'].create_node(self.pshost, 'testnode1'))
+
+ def test003deletenode(self):
+ """Delete recently created node"""
+ self.failUnless(self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode1'))
+
+ def test004createnode(self):
+ """Create a node with a config"""
+ self.statev['defaultconfig'].field['pubsub#access_model'].setValue('open')
+ self.statev['defaultconfig'].field['pubsub#notify_retract'].setValue(True)
+ self.statev['defaultconfig'].field['pubsub#persist_items'].setValue(True)
+ self.statev['defaultconfig'].field['pubsub#presence_based_delivery'].setValue(True)
+ p = self.xmpp2.Presence()
+ p['to'] = self.pshost
+ p.send()
+ self.failUnless(self.xmpp1['xep_0060'].create_node(self.pshost, 'testnode2', self.statev['defaultconfig'], ntype='job'))
+
+ def test005reconfigure(self):
+ """Retrieving node config and reconfiguring"""
+ nconfig = self.xmpp1['xep_0060'].getNodeConfig(self.pshost, 'testnode2')
+ self.failUnless(nconfig, "No configuration returned")
+ #print("\n%s ==\n %s" % (nconfig.getValues(), self.statev['defaultconfig'].getValues()))
+ self.failUnless(nconfig.getValues() == self.statev['defaultconfig'].getValues(), "Configuration does not match")
+ self.failUnless(self.xmpp1['xep_0060'].setNodeConfig(self.pshost, 'testnode2', nconfig))
+
+ def test006subscribetonode(self):
+ """Subscribe to node from account 2"""
+ self.failUnless(self.xmpp2['xep_0060'].subscribe(self.pshost, "testnode2"))
+
+ def test007publishitem(self):
+ """Publishing item"""
+ item = ET.Element('{http://netflint.net/protocol/test}test')
+ w = Waiter('wait publish', StanzaPath('message/pubsub_event/items'))
+ self.xmpp2.registerHandler(w)
+ #result = self.xmpp1['xep_0060'].setItem(self.pshost, "testnode2", (('test1', item),))
+ result = self.xmpp1['jobs'].createJob(self.pshost, "testnode2", 'test1', item)
+ msg = w.wait(5) # got to get a result in 5 seconds
+ self.failUnless(msg != False, "Account #2 did not get message event")
+ #result = self.xmpp1['xep_0060'].setItem(self.pshost, "testnode2", (('test2', item),))
+ result = self.xmpp1['jobs'].createJob(self.pshost, "testnode2", 'test2', item)
+ w = Waiter('wait publish2', StanzaPath('message/pubsub_event/items'))
+ self.xmpp2.registerHandler(w)
+ self.xmpp2['jobs'].claimJob(self.pshost, 'testnode2', 'test1')
+ msg = w.wait(5) # got to get a result in 5 seconds
+ self.xmpp2['jobs'].claimJob(self.pshost, 'testnode2', 'test2')
+ self.xmpp2['jobs'].finishJob(self.pshost, 'testnode2', 'test1')
+ self.xmpp2['jobs'].finishJob(self.pshost, 'testnode2', 'test2')
+ print result
+ #need to add check for update
+
+ def test900cleanup(self):
+ "Cleaning up"
+ #self.failUnless(self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode2'), "Could not delete test node.")
+ time.sleep(10)
+
+
+if __name__ == '__main__':
+ #parse command line arguements
+ optp = OptionParser()
+ optp.add_option('-q','--quiet', help='set logging to ERROR', action='store_const', dest='loglevel', const=logging.ERROR, default=logging.INFO)
+ optp.add_option('-d','--debug', help='set logging to DEBUG', action='store_const', dest='loglevel', const=logging.DEBUG, default=logging.INFO)
+ optp.add_option('-v','--verbose', help='set logging to COMM', action='store_const', dest='loglevel', const=5, default=logging.INFO)
+ optp.add_option("-c","--config", dest="configfile", default="config.xml", help="set config file to use")
+ optp.add_option("-n","--nodenum", dest="nodenum", default="1", help="set node number to use")
+ optp.add_option("-p","--pubsub", dest="pubsub", default="1", help="set pubsub host to use")
+ opts,args = optp.parse_args()
+
+ logging.basicConfig(level=opts.loglevel, format='%(levelname)-8s %(message)s')
+
+ #load xml config
+ logging.info("Loading config file: %s" % opts.configfile)
+ config = configparser.RawConfigParser()
+ config.read(opts.configfile)
+
+ #init
+ logging.info("Account 1 is %s" % config.get('account1', 'jid'))
+ xmpp1 = TestClient(config.get('account1','jid'), config.get('account1','pass'))
+ logging.info("Account 2 is %s" % config.get('account2', 'jid'))
+ xmpp2 = TestClient(config.get('account2','jid'), config.get('account2','pass'))
+
+ xmpp1.registerPlugin('xep_0004')
+ xmpp1.registerPlugin('xep_0030')
+ xmpp1.registerPlugin('xep_0060')
+ xmpp1.registerPlugin('xep_0199')
+ xmpp1.registerPlugin('jobs')
+ xmpp2.registerPlugin('xep_0004')
+ xmpp2.registerPlugin('xep_0030')
+ xmpp2.registerPlugin('xep_0060')
+ xmpp2.registerPlugin('xep_0199')
+ xmpp2.registerPlugin('jobs')
+
+ if not config.get('account1', 'server'):
+ # we don't know the server, but the lib can probably figure it out
+ xmpp1.connect()
+ else:
+ xmpp1.connect((config.get('account1', 'server'), 5222))
+ xmpp1.process(threaded=True)
+
+ #init
+ if not config.get('account2', 'server'):
+ # we don't know the server, but the lib can probably figure it out
+ xmpp2.connect()
+ else:
+ xmpp2.connect((config.get('account2', 'server'), 5222))
+ xmpp2.process(threaded=True)
+
+ TestPubsubServer.xmpp1 = xmpp1
+ TestPubsubServer.xmpp2 = xmpp2
+ TestPubsubServer.pshost = config.get('settings', 'pubsub')
+ xmpp1.waitforstart.get(True)
+ xmpp2.waitforstart.get(True)
+ testsuite = unittest.TestLoader().loadTestsFromTestCase(TestPubsubServer)
+
+ alltests_suite = unittest.TestSuite([testsuite])
+ result = unittest.TextTestRunner(verbosity=2).run(alltests_suite)
+ xmpp1.disconnect()
+ xmpp2.disconnect()
diff --git a/conn_tests/test_pubsubserver.py b/conn_tests/test_pubsubserver.py
index d1e2208f..15635b4b 100644
--- a/conn_tests/test_pubsubserver.py
+++ b/conn_tests/test_pubsubserver.py
@@ -5,7 +5,6 @@ from xml.etree import cElementTree as ET
import os
import time
import sys
-import thread
import unittest
import sleekxmpp.plugins.xep_0004
from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath
diff --git a/sleekxmpp/__init__.py b/sleekxmpp/__init__.py
index 263f1f99..b62e35db 100644
--- a/sleekxmpp/__init__.py
+++ b/sleekxmpp/__init__.py
@@ -1,13 +1,12 @@
#!/usr/bin/python2.5
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2010 Nathanael C. Fritz
- This file is part of SleekXMPP.
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz
+ This file is part of SleekXMPP.
- See the file license.txt for copying permission.
+ See the file license.txt for copying permission.
"""
-from __future__ import absolute_import, unicode_literals
from . basexmpp import basexmpp
from xml.etree import cElementTree as ET
from . xmlstream.xmlstream import XMLStream
@@ -27,10 +26,15 @@ import sys
import random
import copy
from . import plugins
+from xml.etree.cElementTree import tostring
+from xml.etree.cElementTree import Element
+from cStringIO import StringIO
+
#from . import stanza
srvsupport = True
try:
import dns.resolver
+ import dns.rdatatype
except ImportError:
srvsupport = False
@@ -53,12 +57,14 @@ class ClientXMPP(basexmpp, XMLStream):
self.plugin_config = plugin_config
self.escape_quotes = escape_quotes
self.set_jid(jid)
+ self.server = None
+ self.port = 5222 # not used if DNS SRV is used
self.plugin_whitelist = plugin_whitelist
self.auto_reconnect = True
self.srvsupport = srvsupport
self.password = password
self.registered_features = []
- self.stream_header = """<stream:stream to='%s' xmlns:stream='http://etherx.jabber.org/streams' xmlns='%s' version='1.0'>""" % (self.server,self.default_ns)
+ self.stream_header = """<stream:stream to='%s' xmlns:stream='http://etherx.jabber.org/streams' xmlns='%s' version='1.0'>""" % (self.domain,self.default_ns)
self.stream_footer = "</stream:stream>"
#self.map_namespace('http://etherx.jabber.org/streams', 'stream')
#self.map_namespace('jabber:client', '')
@@ -66,8 +72,16 @@ class ClientXMPP(basexmpp, XMLStream):
#TODO: Use stream state here
self.authenticated = False
self.sessionstarted = False
- self.registerHandler(Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True))
- self.registerHandler(Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True))
+ self.bound = False
+ self.bindfail = False
+ self.digest_auth_started = False
+ XMLStream.registerHandler(self, Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True))
+ XMLStream.registerHandler(self, Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True))
+ #SASL Auth handlers
+ basexmpp.add_handler(self, "<challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_sasl_digest_md5_auth, instream=True)
+ basexmpp.add_handler(self, "<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/>", self.handler_sasl_digest_md5_auth_fail, instream=True)
+ basexmpp.add_handler(self, "<success xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_success, instream=True)
+ basexmpp.add_handler(self, "<failure xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_fail, instream=True)
#self.registerHandler(Callback('Roster Update', MatchXMLMask("<presence xmlns='%s' type='subscribe' />" % self.default_ns), self._handlePresenceSubscribe, thread=True))
self.registerFeature("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls' />", self.handler_starttls, True)
self.registerFeature("<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_sasl_auth, True)
@@ -87,12 +101,18 @@ class ClientXMPP(basexmpp, XMLStream):
def get(self, key, default):
return self.plugin.get(key, default)
- def connect(self, address=tuple()):
+ def connect(self, host=None, port=None):
"""Connect to the Jabber Server. Attempts SRV lookup, and if it fails, uses
the JID server."""
- if not address or len(address) < 2:
+
+ if self.state['connected']: return True
+
+ if host:
+ self.server = host
+ if port is None: port = self.port
+ else:
if not self.srvsupport:
- logging.debug("Did not supply (address, port) to connect to and no SRV support is installed (http://www.dnspython.org). Continuing to attempt connection, using server hostname from JID.")
+ logging.debug("Did not supply (address, port) to connect to and no SRV support is installed (http://www.dnspython.org). Continuing to attempt connection, using domain from JID.")
else:
logging.debug("Since no address is supplied, attempting SRV lookup.")
try:
@@ -113,12 +133,19 @@ class ClientXMPP(basexmpp, XMLStream):
picked = random.randint(0, intmax)
for priority in priorities:
if picked <= priority:
- address = addresses[priority]
+ (host,port) = addresses[priority]
break
- if not address:
+ # if SRV lookup was successful, we aren't using a particular server.
+ self.server = None
+
+ if not host:
# if all else fails take server from JID.
- address = (self.server, 5222)
- result = XMLStream.connect(self, address[0], address[1], use_tls=True)
+ (host,port) = (self.domain, self.port)
+ self.server = None
+
+ logging.debug('Attempting connection to %s:%d', host, port )
+ #TODO option to not use TLS?
+ result = XMLStream.connect(self, host, port, use_tls=True)
if result:
self.event("connected")
else:
@@ -129,12 +156,12 @@ class ClientXMPP(basexmpp, XMLStream):
# overriding reconnect and disconnect so that we can get some events
# should events be part of or required by xmlstream? Maybe that would be cleaner
def reconnect(self):
- logging.info("Reconnecting")
- self.event("disconnected")
- XMLStream.reconnect(self)
+ self.disconnect(reconnect=True)
- def disconnect(self, init=True, close=False, reconnect=False):
+ def disconnect(self, reconnect=False):
self.event("disconnected")
+ self.authenticated = False
+ self.sessionstarted = False
XMLStream.disconnect(self, reconnect)
def registerFeature(self, mask, pointer, breaker = False):
@@ -155,6 +182,7 @@ class ClientXMPP(basexmpp, XMLStream):
self._handleRoster(iq, request=True)
def _handleStreamFeatures(self, features):
+ logging.debug('handling stream features')
self.features = []
for sub in features.xml:
self.features.append(sub.tag)
@@ -162,13 +190,17 @@ class ClientXMPP(basexmpp, XMLStream):
for feature in self.registered_features:
if feature[0].match(subelement):
#if self.maskcmp(subelement, feature[0], True):
+ # This calls the feature handler & optionally breaks
if feature[1](subelement) and feature[2]: #if breaker, don't continue
return True
def handler_starttls(self, xml):
+ logging.debug( 'TLS start handler; SSL support: %s', self.ssl_support )
if not self.authenticated and self.ssl_support:
- self.add_handler("<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls' />", self.handler_tls_start, instream=True)
- self.sendXML(xml)
+ _stanza = "<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls' />"
+ if not self.event_handlers.get(_stanza,None): # don't add handler > once
+ self.add_handler( _stanza, self.handler_tls_start, instream=True )
+ self.sendPriorityRaw(self.tostring(xml))
return True
else:
logging.warning("The module tlslite is required in to some servers, and has not been found.")
@@ -183,17 +215,17 @@ class ClientXMPP(basexmpp, XMLStream):
if '{urn:ietf:params:xml:ns:xmpp-tls}starttls' in self.features:
return False
logging.debug("Starting SASL Auth")
- self.add_handler("<success xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_success, instream=True)
- self.add_handler("<failure xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_fail, instream=True)
sasl_mechs = xml.findall('{urn:ietf:params:xml:ns:xmpp-sasl}mechanism')
if len(sasl_mechs):
for sasl_mech in sasl_mechs:
self.features.append("sasl:%s" % sasl_mech.text)
- if 'sasl:PLAIN' in self.features:
+ if 'sasl:DIGEST-MD5' in self.features:
+ self.sendPriorityRaw("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>""")
+ elif 'sasl:PLAIN' in self.features:
if sys.version_info < (3,0):
- self.send("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username) + b'\x00' + bytes(self.password)).decode('utf-8'))
+ self.sendPriorityRaw("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username) + b'\x00' + bytes(self.password)).decode('utf-8'))
else:
- self.send("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username, 'utf-8') + b'\x00' + bytes(self.password, 'utf-8')).decode('utf-8'))
+ self.sendPriorityRaw("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username, 'utf-8') + b'\x00' + bytes(self.password, 'utf-8')).decode('utf-8'))
else:
logging.error("No appropriate login method.")
self.disconnect()
@@ -201,13 +233,50 @@ class ClientXMPP(basexmpp, XMLStream):
# self._auth_digestmd5()
return True
+ def handler_sasl_digest_md5_auth(self, xml):
+ logging.debug(tostring(xml))
+ logging.debug(xml)
+ logging.debug(type(xml).__name__)
+
+ if self.digest_auth_started == False:
+ challenge = [item.split('=', 1) for item in base64.b64decode(xml.text).replace("\"", "").split(',', 6) ]
+ challenge = dict(challenge)
+ logging.debug(challenge)
+
+ #Realm, nonce, qop should all be present
+ if not challenge['realm'] or not challenge['qop'] or not challenge['nonce']:
+ logging.error("Error during digest-md5 authentication. Challenge missing critical information. Challenge: %s" %base64.b64decode(xml.text))
+ self.disconnect()
+ self.event("failed_auth")
+ return
+ #TODO: charset can be either UTF-8 or if not present use ISO 8859-1 defaulting for UTF-8 for now
+ #Compute the cnonce - a unique hex string only used in this request
+ cnonce = ""
+ for i in range(7):
+ cnonce+=hex(int(random.random()*65536*4096))[2:]
+ cnonce = base64.encodestring(cnonce)[0:-1]
+ a1 = md5("%s:%s:%s" % (self.username, self.domain, self.password))
+ a1 = "%s:%s:%s" %(a1, challenge["nonce"], cnonce )
+ a2 = "AUTHENTICATE:xmpp/%s" %self.domain
+ responseHash = md5digest("%s:%s:00000001:%s:auth:%s" %(md5digest(a1), challenge["nonce"], cnonce, md5digest(a2) ) )
+ response = '''charset=utf-8,username="%s",realm="%s",nonce="%s",nc=00000001,cnonce="%s",digest-uri="%s",response=%s,qop=%s,''' %(self.username, self.domain, challenge["nonce"], cnonce, "xmpp/%s" % self.domain, responseHash, challenge["qop"])
+ self.sendPriorityRaw("""<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>%s</response>""" %base64.encodestring(response)[:-1])
+ else:
+ logging.warn("handler_sasl_digest_md5_auth called while digest_auth_started is True (has already begun)")
+
+ def handler_sasl_digest_md5_auth_fail(self, xml):
+ self.digest_auth_started = False
+ self.handler_auth_fail(xml)
+
def handler_auth_success(self, xml):
+ logging.debug("Authentication successful.")
self.authenticated = True
self.features = []
raise RestartStream()
def handler_auth_fail(self, xml):
- logging.info("Authentication failed.")
+ logging.warning("Authentication failed.")
+ logging.debug(tostring(xml, 'utf-8'))
self.disconnect()
self.event("failed_auth")
@@ -221,19 +290,23 @@ class ClientXMPP(basexmpp, XMLStream):
response = iq.send()
#response = self.send(iq, self.Iq(sid=iq['id']))
self.set_jid(response.xml.find('{urn:ietf:params:xml:ns:xmpp-bind}bind/{urn:ietf:params:xml:ns:xmpp-bind}jid').text)
+ self.bound = True
logging.info("Node set to: %s" % self.fulljid)
- if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features:
+ if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features or self.bindfail:
logging.debug("Established Session")
self.sessionstarted = True
self.event("session_start")
def handler_start_session(self, xml):
- if self.authenticated:
+ if self.authenticated and self.bound:
iq = self.makeIqSet(xml)
response = iq.send()
logging.debug("Established Session")
self.sessionstarted = True
self.event("session_start")
+ else:
+ #bind probably hasn't happened yet
+ self.bindfail = True
def _handleRoster(self, iq, request=False):
if iq['type'] == 'set' or (iq['type'] == 'result' and request):
@@ -244,3 +317,21 @@ class ClientXMPP(basexmpp, XMLStream):
if iq['type'] == 'set':
self.send(self.Iq().setValues({'type': 'result', 'id': iq['id']}).enable('roster'))
self.event("roster_update", iq)
+
+def md5(data):
+ try:
+ import hashlib
+ md5 = hashlib.md5(data)
+ except ImportError:
+ import md5
+ md5 = md5.new(data)
+ return md5.digest()
+
+def md5digest(data):
+ try:
+ import hashlib
+ md5 = hashlib.md5(data)
+ except ImportError:
+ import md5
+ md5 = md5.new(data)
+ return md5.hexdigest()
diff --git a/sleekxmpp/basexmpp.py b/sleekxmpp/basexmpp.py
index 907067fa..936b5d18 100644
--- a/sleekxmpp/basexmpp.py
+++ b/sleekxmpp/basexmpp.py
@@ -1,9 +1,9 @@
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2010 Nathanael C. Fritz
- This file is part of SleekXMPP.
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz
+ This file is part of SleekXMPP.
- See the file license.txt for copying permission.
+ See the file license.txt for copying permission.
"""
from __future__ import with_statement, unicode_literals
@@ -49,7 +49,7 @@ class basexmpp(object):
self.resource = ''
self.jid = ''
self.username = ''
- self.server = ''
+ self.domain = ''
self.plugin = {}
self.auto_authorize = True
self.auto_subscribe = True
@@ -84,28 +84,35 @@ class basexmpp(object):
self.resource = self.getjidresource(jid)
self.jid = self.getjidbare(jid)
self.username = jid.split('@', 1)[0]
- self.server = jid.split('@',1)[-1].split('/', 1)[0]
+ self.domain = jid.split('@',1)[-1].split('/', 1)[0]
def process(self, *args, **kwargs):
for idx in self.plugin:
if not self.plugin[idx].post_inited: self.plugin[idx].post_init()
return super(basexmpp, self).process(*args, **kwargs)
- def registerPlugin(self, plugin, pconfig = {}):
+ def registerPlugin(self, plugin, pconfig = {}, pluginModule = None):
"""Register a plugin not in plugins.__init__.__all__ but in the plugins
directory."""
# discover relative "path" to the plugins module from the main app, and import it.
# TODO:
# gross, this probably isn't necessary anymore, especially for an installed module
- __import__("%s.%s" % (globals()['plugins'].__name__, plugin))
- # init the plugin class
- self.plugin[plugin] = getattr(getattr(plugins, plugin), plugin)(self, pconfig) # eek
- # all of this for a nice debug? sure.
- xep = ''
- if hasattr(self.plugin[plugin], 'xep'):
- xep = "(XEP-%s) " % self.plugin[plugin].xep
- logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description))
-
+ try:
+ if pluginModule:
+ module = __import__(pluginModule, globals(), locals(), [plugin])
+ else:
+ module = __import__("%s.%s" % (globals()['plugins'].__name__, plugin), globals(), locals(), [plugin])
+ # init the plugin class
+ self.plugin[plugin] = getattr(module, plugin)(self, pconfig) # eek
+ # all of this for a nice debug? sure.
+ xep = ''
+ if hasattr(self.plugin[plugin], 'xep'):
+ xep = "(XEP-%s) " % self.plugin[plugin].xep
+ logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description))
+ except Exception, e:
+ logging.error("Unable to load plugin: %s" %(plugin) )
+ logging.exception(e)
+
def register_plugins(self):
"""Initiates all plugins in the plugins/__init__.__all__"""
if self.plugin_whitelist:
diff --git a/sleekxmpp/plugins/jobs.py b/sleekxmpp/plugins/jobs.py
new file mode 100644
index 00000000..bb2e2554
--- /dev/null
+++ b/sleekxmpp/plugins/jobs.py
@@ -0,0 +1,44 @@
+from . import base
+import logging
+from xml.etree import cElementTree as ET
+
+class jobs(base.base_plugin):
+ def plugin_init(self):
+ self.xep = 'pubsubjob'
+ self.description = "Job distribution over Pubsub"
+
+ def post_init(self):
+ pass
+ #TODO add event
+
+ def createJobNode(self, host, jid, node, config=None):
+ pass
+
+ def createJob(self, host, node, jobid=None, payload=None):
+ return self.xmpp.plugin['xep_0060'].setItem(host, node, ((jobid, payload),))
+
+ def claimJob(self, host, node, jobid, ifrom=None):
+ return self._setState(host, node, jobid, ET.Element('{http://andyet.net/protocol/pubsubjob}claimed'))
+
+ def unclaimJob(self, jobid):
+ return self._setState(host, node, jobid, ET.Element('{http://andyet.net/protocol/pubsubjob}unclaimed'))
+
+ def finishJob(self, host, node, jobid, payload=None):
+ finished = ET.Element('{http://andyet.net/protocol/pubsubjob}finished')
+ if payload is not None:
+ finished.append(payload)
+ return self._setState(host, node, jobid, finished)
+
+ def _setState(self, host, node, jobid, state, ifrom=None):
+ iq = self.xmpp.Iq()
+ iq['to'] = host
+ if ifrom: iq['from'] = ifrom
+ iq['type'] = 'set'
+ iq['psstate']['node'] = node
+ iq['psstate']['item'] = jobid
+ iq['psstate']['payload'] = state
+ result = iq.send()
+ if result is None or result['type'] != 'result':
+ return False
+ return True
+
diff --git a/sleekxmpp/plugins/stanza_pubsub.py b/sleekxmpp/plugins/stanza_pubsub.py
index 1dd73d99..1a1526f0 100644
--- a/sleekxmpp/plugins/stanza_pubsub.py
+++ b/sleekxmpp/plugins/stanza_pubsub.py
@@ -10,6 +10,39 @@ def stanzaPlugin(stanza, plugin):
stanza.plugin_attrib_map[plugin.plugin_attrib] = plugin
stanza.plugin_tag_map["{%s}%s" % (plugin.namespace, plugin.name)] = plugin
+class PubsubState(ElementBase):
+ namespace = 'http://jabber.org/protocol/psstate'
+ name = 'state'
+ plugin_attrib = 'psstate'
+ interfaces = set(('node', 'item', 'payload'))
+ plugin_attrib_map = {}
+ plugin_tag_map = {}
+
+ def setPayload(self, value):
+ self.xml.append(value)
+
+ def getPayload(self):
+ childs = self.xml.getchildren()
+ if len(childs) > 0:
+ return childs[0]
+
+ def delPayload(self):
+ for child in self.xml.getchildren():
+ self.xml.remove(child)
+
+stanzaPlugin(Iq, PubsubState)
+
+class PubsubStateEvent(ElementBase):
+ namespace = 'http://jabber.org/protocol/psstate#event'
+ name = 'event'
+ plugin_attrib = 'psstate_event'
+ intefaces = set(tuple())
+ plugin_attrib_map = {}
+ plugin_tag_map = {}
+
+stanzaPlugin(Message, PubsubStateEvent)
+stanzaPlugin(PubsubStateEvent, PubsubState)
+
class Pubsub(ElementBase):
namespace = 'http://jabber.org/protocol/pubsub'
name = 'pubsub'
@@ -321,18 +354,6 @@ class Options(ElementBase):
stanzaPlugin(Pubsub, Options)
stanzaPlugin(Subscribe, Options)
-#iq = Iq()
-#iq['pubsub']['defaultconfig']
-#print(iq)
-
-#from xml.etree import cElementTree as ET
-#iq = Iq()
-#item = Item()
-#item['payload'] = ET.Element("{http://netflint.net/p/crap}stupidshit")
-#item['id'] = 'aa11bbcc'
-#iq['pubsub']['items'].append(item)
-#print(iq)
-
class OwnerAffiliations(Affiliations):
namespace = 'http://jabber.org/protocol/pubsub#owner'
interfaces = set(('node'))
diff --git a/sleekxmpp/plugins/xep_0004.py b/sleekxmpp/plugins/xep_0004.py
index 56d18929..015bd8bc 100644
--- a/sleekxmpp/plugins/xep_0004.py
+++ b/sleekxmpp/plugins/xep_0004.py
@@ -188,7 +188,6 @@ class Form(FieldContainer):
#def getXML(self, tostring = False):
def getXML(self, ftype=None):
- logging.debug("creating form as %s" % ftype)
if ftype:
self.type = ftype
form = ET.Element('{jabber:x:data}x')
diff --git a/sleekxmpp/plugins/xep_0030.py b/sleekxmpp/plugins/xep_0030.py
index 5432dd56..6a31d243 100644
--- a/sleekxmpp/plugins/xep_0030.py
+++ b/sleekxmpp/plugins/xep_0030.py
@@ -1,25 +1,184 @@
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2007 Nathanael C. Fritz
- This file is part of SleekXMPP.
-
- SleekXMPP is free software; you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation; either version 2 of the License, or
- (at your option) any later version.
-
- SleekXMPP is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
-
- You should have received a copy of the GNU General Public License
- along with SleekXMPP; if not, write to the Free Software
- Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz, Lance J.T. Stout
+ This file is part of SleekXMPP.
+
+ See the file license.txt for copying permissio
"""
-from . import base
+
import logging
-from xml.etree import cElementTree as ET
+from . import base
+from .. xmlstream.handler.callback import Callback
+from .. xmlstream.matcher.xpath import MatchXPath
+from .. xmlstream.stanzabase import ElementBase, ET, JID
+from .. stanza.iq import Iq
+
+class DiscoInfo(ElementBase):
+ namespace = 'http://jabber.org/protocol/disco#info'
+ name = 'query'
+ plugin_attrib = 'disco_info'
+ interfaces = set(('node', 'features', 'identities'))
+
+ def getFeatures(self):
+ features = []
+ featuresXML = self.xml.findall('{%s}feature' % self.namespace)
+ for feature in featuresXML:
+ features.append(feature.attrib['var'])
+ return features
+
+ def setFeatures(self, features):
+ self.delFeatures()
+ for name in features:
+ self.addFeature(name)
+
+ def delFeatures(self):
+ featuresXML = self.xml.findall('{%s}feature' % self.namespace)
+ for feature in featuresXML:
+ self.xml.remove(feature)
+
+ def addFeature(self, feature):
+ featureXML = ET.Element('{%s}feature' % self.namespace,
+ {'var': feature})
+ self.xml.append(featureXML)
+
+ def delFeature(self, feature):
+ featuresXML = self.xml.findall('{%s}feature' % self.namespace)
+ for featureXML in featuresXML:
+ if featureXML.attrib['var'] == feature:
+ self.xml.remove(featureXML)
+
+ def getIdentities(self):
+ ids = []
+ idsXML = self.xml.findall('{%s}identity' % self.namespace)
+ for idXML in idsXML:
+ idData = (idXML.attrib['category'],
+ idXML.attrib['type'],
+ idXML.attrib.get('name', ''))
+ ids.append(idData)
+ return ids
+
+ def setIdentities(self, ids):
+ self.delIdentities()
+ for idData in ids:
+ self.addIdentity(*idData)
+
+ def delIdentities(self):
+ idsXML = self.xml.findall('{%s}identity' % self.namespace)
+ for idXML in idsXML:
+ self.xml.remove(idXML)
+
+ def addIdentity(self, category, id_type, name=''):
+ idXML = ET.Element('{%s}identity' % self.namespace,
+ {'category': category,
+ 'type': id_type,
+ 'name': name})
+ self.xml.append(idXML)
+
+ def delIdentity(self, category, id_type, name=''):
+ idsXML = self.xml.findall('{%s}identity' % self.namespace)
+ for idXML in idsXML:
+ idData = (idXML.attrib['category'],
+ idXML.attrib['type'])
+ delId = (category, id_type)
+ if idData == delId:
+ self.xml.remove(idXML)
+
+
+class DiscoItems(ElementBase):
+ namespace = 'http://jabber.org/protocol/disco#items'
+ name = 'query'
+ plugin_attrib = 'disco_items'
+ interfaces = set(('node', 'items'))
+
+ def getItems(self):
+ items = []
+ itemsXML = self.xml.findall('{%s}item' % self.namespace)
+ for item in itemsXML:
+ itemData = (item.attrib['jid'],
+ item.attrib.get('node'),
+ item.attrib.get('name'))
+ items.append(itemData)
+ return items
+
+ def setItems(self, items):
+ self.delItems()
+ for item in items:
+ self.addItem(*item)
+
+ def delItems(self):
+ itemsXML = self.xml.findall('{%s}item' % self.namespace)
+ for item in itemsXML:
+ self.xml.remove(item)
+
+ def addItem(self, jid, node='', name=''):
+ itemXML = ET.Element('{%s}item' % self.namespace, {'jid': jid})
+ if name:
+ itemXML.attrib['name'] = name
+ if node:
+ itemXML.attrib['node'] = node
+ self.xml.append(itemXML)
+
+ def delItem(self, jid, node=''):
+ itemsXML = self.xml.findall('{%s}item' % self.namespace)
+ for itemXML in itemsXML:
+ itemData = (itemXML.attrib['jid'],
+ itemXML.attrib.get('node', ''))
+ itemDel = (jid, node)
+ if itemData == itemDel:
+ self.xml.remove(itemXML)
+
+
+class DiscoNode(object):
+ """
+ Collection object for grouping info and item information
+ into nodes.
+ """
+ def __init__(self, name):
+ self.name = name
+ self.info = DiscoInfo()
+ self.items = DiscoItems()
+
+ # This is a bit like poor man's inheritance, but
+ # to simplify adding information to the node we
+ # map node functions to either the info or items
+ # stanza objects.
+ #
+ # We don't want to make DiscoNode inherit from
+ # DiscoInfo and DiscoItems because DiscoNode is
+ # not an actual stanza, and doing so would create
+ # confusion and potential bugs.
+
+ self._map(self.items, 'items', ['get', 'set', 'del'])
+ self._map(self.items, 'item', ['add', 'del'])
+ self._map(self.info, 'identities', ['get', 'set', 'del'])
+ self._map(self.info, 'identity', ['add', 'del'])
+ self._map(self.info, 'features', ['get', 'set', 'del'])
+ self._map(self.info, 'feature', ['add', 'del'])
+
+ def isEmpty(self):
+ """
+ Test if the node contains any information. Useful for
+ determining if a node can be deleted.
+ """
+ ids = self.getIdentities()
+ features = self.getFeatures()
+ items = self.getItems()
+
+ if not ids and not features and not items:
+ return True
+ return False
+
+ def _map(self, obj, interface, access):
+ """
+ Map functions of the form obj.accessInterface
+ to self.accessInterface for each given access type.
+ """
+ interface = interface.title()
+ for access_type in access:
+ method = access_type + interface
+ if hasattr(obj, method):
+ setattr(self, method, getattr(obj, method))
+
class xep_0030(base.base_plugin):
"""
@@ -29,85 +188,137 @@ class xep_0030(base.base_plugin):
def plugin_init(self):
self.xep = '0030'
self.description = 'Service Discovery'
- self.features = {'main': ['http://jabber.org/protocol/disco#info', 'http://jabber.org/protocol/disco#items']}
- self.identities = {'main': [{'category': 'client', 'type': 'pc', 'name': 'SleekXMPP'}]}
- self.items = {'main': []}
- self.xmpp.add_handler("<iq type='get' xmlns='%s'><query xmlns='http://jabber.org/protocol/disco#info' /></iq>" % self.xmpp.default_ns, self.info_handler)
- self.xmpp.add_handler("<iq type='get' xmlns='%s'><query xmlns='http://jabber.org/protocol/disco#items' /></iq>" % self.xmpp.default_ns, self.item_handler)
+
+ self.xmpp.registerHandler(
+ Callback('Disco Items',
+ MatchXPath('{%s}iq/{%s}query' % (self.xmpp.default_ns,
+ DiscoItems.namespace)),
+ self.handle_item_query))
+
+ self.xmpp.registerHandler(
+ Callback('Disco Info',
+ MatchXPath('{%s}iq/{%s}query' % (self.xmpp.default_ns,
+ DiscoInfo.namespace)),
+ self.handle_info_query))
+
+ self.xmpp.stanzaPlugin(Iq, DiscoInfo)
+ self.xmpp.stanzaPlugin(Iq, DiscoItems)
+
+ self.xmpp.add_event_handler('disco_items_request', self.handle_disco_items)
+ self.xmpp.add_event_handler('disco_info_request', self.handle_disco_info)
+
+ self.nodes = {'main': DiscoNode('main')}
+
+ def add_node(self, node):
+ if node not in self.nodes:
+ self.nodes[node] = DiscoNode(node)
+
+ def del_node(self, node):
+ if node in self.nodes:
+ del self.nodes[node]
+
+ def handle_item_query(self, iq):
+ if iq['type'] == 'get':
+ logging.debug("Items requested by %s" % iq['from'])
+ self.xmpp.event('disco_items_request', iq)
+ elif iq['type'] == 'result':
+ logging.debug("Items result from %s" % iq['from'])
+ self.xmpp.event('disco_items', iq)
+
+ def handle_info_query(self, iq):
+ if iq['type'] == 'get':
+ logging.debug("Info requested by %s" % iq['from'])
+ self.xmpp.event('disco_info_request', iq)
+ elif iq['type'] == 'result':
+ logging.debug("Info result from %s" % iq['from'])
+ self.xmpp.event('disco_info', iq)
+
+ def handle_disco_info(self, iq, forwarded=False):
+ """
+ A default handler for disco#info requests. If another
+ handler is registered, this one will defer and not run.
+ """
+ handlers = self.xmpp.event_handlers['disco_info_request']
+ if not forwarded and len(handlers) > 1:
+ return
+
+ node_name = iq['disco_info']['node']
+ if not node_name:
+ node_name = 'main'
+
+ logging.debug("Using default handler for disco#info on node '%s'." % node_name)
+
+ if node_name in self.nodes:
+ node = self.nodes[node_name]
+ iq.reply().setPayload(node.info.xml).send()
+ else:
+ logging.debug("Node %s requested, but does not exist." % node_name)
+ iq.reply().error().setPayload(iq['disco_info'].xml)
+ iq['error']['code'] = '404'
+ iq['error']['type'] = 'cancel'
+ iq['error']['condition'] = 'item-not-found'
+ iq.send()
+
+ def handle_disco_items(self, iq, forwarded=False):
+ """
+ A default handler for disco#items requests. If another
+ handler is registered, this one will defer and not run.
+
+ If this handler is called by your own custom handler with
+ forwarded set to True, then it will run as normal.
+ """
+ handlers = self.xmpp.event_handlers['disco_items_request']
+ if not forwarded and len(handlers) > 1:
+ return
+
+ node_name = iq['disco_items']['node']
+ if not node_name:
+ node_name = 'main'
+
+ logging.debug("Using default handler for disco#items on node '%s'." % node_name)
+
+ if node_name in self.nodes:
+ node = self.nodes[node_name]
+ iq.reply().setPayload(node.items.xml).send()
+ else:
+ logging.debug("Node %s requested, but does not exist." % node_name)
+ iq.reply().error().setPayload(iq['disco_items'].xml)
+ iq['error']['code'] = '404'
+ iq['error']['type'] = 'cancel'
+ iq['error']['condition'] = 'item-not-found'
+ iq.send()
+
+ # Older interface methods for backwards compatibility
+
+ def getInfo(self, jid, node=''):
+ iq = self.xmpp.Iq()
+ iq['type'] = 'get'
+ iq['to'] = jid
+ iq['from'] = self.xmpp.fulljid
+ iq['disco_info']['node'] = node
+ iq.send()
+
+ def getItems(self, jid, node=''):
+ iq = self.xmpp.Iq()
+ iq['type'] = 'get'
+ iq['to'] = jid
+ iq['from'] = self.xmpp.fulljid
+ iq['disco_items']['node'] = node
+ iq.send()
def add_feature(self, feature, node='main'):
- if not node in self.features:
- self.features[node] = []
- self.features[node].append(feature)
-
- def add_identity(self, category=None, itype=None, name=None, node='main'):
- if not node in self.identities:
- self.identities[node] = []
- self.identities[node].append({'category': category, 'type': itype, 'name': name})
+ self.add_node(node)
+ self.nodes[node].addFeature(feature)
- def add_item(self, jid=None, name=None, node='main', subnode=''):
- if not node in self.items:
- self.items[node] = []
- self.items[node].append({'jid': jid, 'name': name, 'node': subnode})
-
- def info_handler(self, xml):
- logging.debug("Info request from %s" % xml.get('from', ''))
- iq = self.xmpp.makeIqResult(xml.get('id', self.xmpp.getNewId()))
- iq.attrib['from'] = xml.get('to')
- iq.attrib['to'] = xml.get('from', self.xmpp.server)
- query = xml.find('{http://jabber.org/protocol/disco#info}query')
- node = query.get('node', 'main')
- for identity in self.identities.get(node, []):
- idxml = ET.Element('identity')
- for attrib in identity:
- if identity[attrib]:
- idxml.attrib[attrib] = identity[attrib]
- query.append(idxml)
- for feature in self.features.get(node, []):
- featxml = ET.Element('feature')
- featxml.attrib['var'] = feature
- query.append(featxml)
- iq.append(query)
- #print ET.tostring(iq)
- self.xmpp.send(iq)
-
- def item_handler(self, xml):
- logging.debug("Item request from %s" % xml.get('from', ''))
- iq = self.xmpp.makeIqResult(xml.get('id', self.xmpp.getNewId()))
- iq.attrib['from'] = xml.get('to')
- iq.attrib['to'] = xml.get('from', self.xmpp.server)
- query = self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#items').find('{http://jabber.org/protocol/disco#items}query')
- node = xml.find('{http://jabber.org/protocol/disco#items}query').get('node', 'main')
- for item in self.items.get(node, []):
- itemxml = ET.Element('item')
- itemxml.attrib = item
- if itemxml.attrib['jid'] is None:
- itemxml.attrib['jid'] = xml.get('to')
- query.append(itemxml)
- self.xmpp.send(iq)
+ def add_identity(self, category='', itype='', name='', node='main'):
+ self.add_node(node)
+ self.nodes[node].addIdentity(category=category,
+ id_type=itype,
+ name=name)
- def getItems(self, jid, node=None):
- iq = self.xmpp.makeIqGet()
- iq.attrib['from'] = self.xmpp.fulljid
- iq.attrib['to'] = jid
- self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#items')
- if node:
- iq.find('{http://jabber.org/protocol/disco#items}query').attrib['node'] = node
- return iq.send()
-
- def getInfo(self, jid, node=None):
- iq = self.xmpp.makeIqGet()
- iq.attrib['from'] = self.xmpp.fulljid
- iq.attrib['to'] = jid
- self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#info')
- if node:
- iq.find('{http://jabber.org/protocol/disco#info}query').attrib['node'] = node
- return iq.send()
-
- def parseInfo(self, xml):
- result = {'identity': {}, 'feature': []}
- for identity in xml.findall('{http://jabber.org/protocol/disco#info}query/{{http://jabber.org/protocol/disco#info}identity'):
- result['identity'][identity['name']] = identity.attrib
- for feature in xml.findall('{http://jabber.org/protocol/disco#info}query/{{http://jabber.org/protocol/disco#info}feature'):
- result['feature'].append(feature.get('var', '__unknown__'))
- return result
+ def add_item(self, jid=None, name='', node='main', subnode=''):
+ self.add_node(node)
+ self.add_node(subnode)
+ if jid is None:
+ jid = self.xmpp.fulljid
+ self.nodes[node].addItem(jid=jid, name=name, node=subnode)
diff --git a/sleekxmpp/plugins/xep_0060.py b/sleekxmpp/plugins/xep_0060.py
index 44a70e9a..bff158a0 100644
--- a/sleekxmpp/plugins/xep_0060.py
+++ b/sleekxmpp/plugins/xep_0060.py
@@ -14,12 +14,14 @@ class xep_0060(base.base_plugin):
self.xep = '0060'
self.description = 'Publish-Subscribe'
- def create_node(self, jid, node, config=None, collection=False):
+ def create_node(self, jid, node, config=None, collection=False, ntype=None):
pubsub = ET.Element('{http://jabber.org/protocol/pubsub}pubsub')
create = ET.Element('create')
create.set('node', node)
pubsub.append(create)
configure = ET.Element('configure')
+ if collection:
+ ntype = 'collection'
#if config is None:
# submitform = self.xmpp.plugin['xep_0004'].makeForm('submit')
#else:
@@ -29,11 +31,11 @@ class xep_0060(base.base_plugin):
submitform.field['FORM_TYPE'].setValue('http://jabber.org/protocol/pubsub#node_config')
else:
submitform.addField('FORM_TYPE', 'hidden', value='http://jabber.org/protocol/pubsub#node_config')
- if collection:
+ if ntype:
if 'pubsub#node_type' in submitform.field:
- submitform.field['pubsub#node_type'].setValue('collection')
+ submitform.field['pubsub#node_type'].setValue(ntype)
else:
- submitform.addField('pubsub#node_type', value='collection')
+ submitform.addField('pubsub#node_type', value=ntype)
else:
if 'pubsub#node_type' in submitform.field:
submitform.field['pubsub#node_type'].setValue('leaf')
diff --git a/sleekxmpp/stanza/error.py b/sleekxmpp/stanza/error.py
index f87b6490..3346ceb2 100644
--- a/sleekxmpp/stanza/error.py
+++ b/sleekxmpp/stanza/error.py
@@ -1,9 +1,9 @@
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2010 Nathanael C. Fritz
- This file is part of SleekXMPP.
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz
+ This file is part of SleekXMPP.
- See the file license.txt for copying permission.
+ See the file license.txt for copying permission.
"""
from .. xmlstream.stanzabase import ElementBase, ET
@@ -11,7 +11,7 @@ class Error(ElementBase):
namespace = 'jabber:client'
name = 'error'
plugin_attrib = 'error'
- conditions = set(('bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'item-not-found', 'jid-malformed', 'not-acceptable', 'not-allowed', 'not-authorized', 'payment-required', 'recipient-unavailable', 'redirect', 'registration-required', 'remote-server-not-found', 'remote-server-timeout', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request'))
+ conditions = set(('bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'internal-server-error', 'item-not-found', 'jid-malformed', 'not-acceptable', 'not-allowed', 'not-authorized', 'payment-required', 'recipient-unavailable', 'redirect', 'registration-required', 'remote-server-not-found', 'remote-server-timeout', 'resource-constraint', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request'))
interfaces = set(('code', 'condition', 'text', 'type'))
types = set(('cancel', 'continue', 'modify', 'auth', 'wait'))
sub_interfaces = set(('text',))
diff --git a/sleekxmpp/stanza/iq.py b/sleekxmpp/stanza/iq.py
index ded7515f..26f09268 100644
--- a/sleekxmpp/stanza/iq.py
+++ b/sleekxmpp/stanza/iq.py
@@ -1,9 +1,9 @@
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2010 Nathanael C. Fritz
- This file is part of SleekXMPP.
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz
+ This file is part of SleekXMPP.
- See the file license.txt for copying permission.
+ See the file license.txt for copying permission.
"""
from .. xmlstream.stanzabase import StanzaBase
from xml.etree import cElementTree as ET
@@ -67,11 +67,11 @@ class Iq(RootStanza):
self.xml.remove(child)
return self
- def send(self, block=True, timeout=10):
+ def send(self, block=True, timeout=10, priority=False):
if block and self['type'] in ('get', 'set'):
waitfor = Waiter('IqWait_%s' % self['id'], MatcherId(self['id']))
self.stream.registerHandler(waitfor)
- StanzaBase.send(self)
+ StanzaBase.send(self, priority)
return waitfor.wait(timeout)
else:
- return StanzaBase.send(self)
+ return StanzaBase.send(self, priority)
diff --git a/sleekxmpp/xmlstream/handler/base.py b/sleekxmpp/xmlstream/handler/base.py
index 5d55f4ee..a44edf0e 100644
--- a/sleekxmpp/xmlstream/handler/base.py
+++ b/sleekxmpp/xmlstream/handler/base.py
@@ -18,7 +18,7 @@ class BaseHandler(object):
def match(self, xml):
return self._matcher.match(xml)
- def prerun(self, payload):
+ def prerun(self, payload): # what's the point of this if the payload is called again in run??
self._payload = payload
def run(self, payload):
diff --git a/sleekxmpp/xmlstream/handler/callback.py b/sleekxmpp/xmlstream/handler/callback.py
index 49cfa14d..ea5acb5b 100644
--- a/sleekxmpp/xmlstream/handler/callback.py
+++ b/sleekxmpp/xmlstream/handler/callback.py
@@ -17,13 +17,15 @@ class Callback(base.BaseHandler):
self._once = once
self._instream = instream
- def prerun(self, payload):
+ def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN!
base.BaseHandler.prerun(self, payload)
if self._instream:
+ logging.debug('callback "%s" prerun', self.name)
self.run(payload, True)
def run(self, payload, instream=False):
if not self._instream or instream:
+ logging.debug('callback "%s" run', self.name)
base.BaseHandler.run(self, payload)
#if self._thread:
# x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,))
diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py
new file mode 100644
index 00000000..40aaf695
--- /dev/null
+++ b/sleekxmpp/xmlstream/scheduler.py
@@ -0,0 +1,87 @@
+try:
+ import queue
+except ImportError:
+ import Queue as queue
+import time
+import threading
+import logging
+
+class Task(object):
+ """Task object for the Scheduler class"""
+ def __init__(self, name, seconds, callback, args=None, kwargs=None, repeat=False, qpointer=None):
+ self.name = name
+ self.seconds = seconds
+ self.callback = callback
+ self.args = args or tuple()
+ self.kwargs = kwargs or {}
+ self.repeat = repeat
+ self.next = time.time() + self.seconds
+ self.qpointer = qpointer
+
+ def run(self):
+ if self.qpointer is not None:
+ self.qpointer.put(('schedule', self.callback, self.args))
+ else:
+ self.callback(*self.args, **self.kwargs)
+ self.reset()
+ return self.repeat
+
+ def reset(self):
+ self.next = time.time() + self.seconds
+
+class Scheduler(object):
+ """Threaded scheduler that allows for updates mid-execution unlike http://docs.python.org/library/sched.html#module-sched"""
+ def __init__(self, parentqueue=None):
+ self.addq = queue.Queue()
+ self.schedule = []
+ self.thread = None
+ self.run = False
+ self.parentqueue = parentqueue
+
+ def process(self, threaded=True):
+ if threaded:
+ self.thread = threading.Thread(name='shedulerprocess', target=self._process)
+ self.thread.start()
+ else:
+ self._process()
+
+ def _process(self):
+ self.run = True
+ while self.run:
+ try:
+ wait = 1
+ updated = False
+ if self.schedule:
+ wait = self.schedule[0].next - time.time()
+ try:
+ if wait <= 0.0:
+ newtask = self.addq.get(False)
+ else:
+ newtask = self.addq.get(True, wait)
+ except queue.Empty:
+ cleanup = []
+ for task in self.schedule:
+ if time.time() >= task.next:
+ updated = True
+ if not task.run():
+ cleanup.append(task)
+ else:
+ break
+ for task in cleanup:
+ x = self.schedule.pop(self.schedule.index(task))
+ else:
+ updated = True
+ self.schedule.append(newtask)
+ finally:
+ if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next)
+ except KeyboardInterrupt:
+ self.run = False
+ logging.debug("Quitting Scheduler thread")
+ if self.parentqueue is not None:
+ self.parentqueue.put(('quit', None, None))
+
+ def add(self, name, seconds, callback, args=None, kwargs=None, repeat=False, qpointer=None):
+ self.addq.put(Task(name, seconds, callback, args, kwargs, repeat, qpointer))
+
+ def quit(self):
+ self.run = False
diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py
index 3f3f5e08..34513807 100644
--- a/sleekxmpp/xmlstream/stanzabase.py
+++ b/sleekxmpp/xmlstream/stanzabase.py
@@ -1,9 +1,9 @@
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2010 Nathanael C. Fritz
- This file is part of SleekXMPP.
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz
+ This file is part of SleekXMPP.
- See the file license.txt for copying permission.
+ See the file license.txt for copying permission.
"""
from xml.etree import cElementTree as ET
import logging
@@ -78,6 +78,9 @@ class ElementBase(tostring.ToString):
def __iter__(self):
self.idx = 0
return self
+
+ def __bool__(self):
+ return True
def __next__(self):
self.idx += 1
@@ -319,6 +322,8 @@ class StanzaBase(ElementBase):
def __init__(self, stream=None, xml=None, stype=None, sto=None, sfrom=None, sid=None):
self.stream = stream
+ if stream is not None:
+ self.namespace = stream.default_ns
ElementBase.__init__(self, xml)
if stype is not None:
self['type'] = stype
@@ -326,8 +331,6 @@ class StanzaBase(ElementBase):
self['to'] = sto
if sfrom is not None:
self['from'] = sfrom
- if stream is not None:
- self.namespace = stream.default_ns
self.tag = "{%s}%s" % (self.namespace, self.name)
def setType(self, value):
@@ -380,6 +383,7 @@ class StanzaBase(ElementBase):
def exception(self, e):
logging.error(traceback.format_tb(e))
- def send(self):
- self.stream.sendRaw(self.__str__())
-
+ def send(self, priority=False):
+ if priority: self.stream.sendPriorityRaw(self.__str__())
+ else: self.stream.sendRaw(self.__str__())
+
diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py
index fb7d1508..67b514a2 100644
--- a/sleekxmpp/xmlstream/statemachine.py
+++ b/sleekxmpp/xmlstream/statemachine.py
@@ -7,53 +7,228 @@
"""
from __future__ import with_statement
import threading
+import time
+import logging
+
class StateMachine(object):
- def __init__(self, states=[], groups=[]):
- self.lock = threading.Lock()
- self.__state = {}
- self.__default_state = {}
- self.__group = {}
+ def __init__(self, states=[]):
+ self.lock = threading.Condition(threading.RLock())
+ self.__states= []
self.addStates(states)
- self.addGroups(groups)
+ self.__default_state = self.__states[0]
+ self.__current_state = self.__default_state
def addStates(self, states):
with self.lock:
for state in states:
- if state in self.__state or state in self.__group:
- raise IndexError("The state or group '%s' is already in the StateMachine." % state)
- self.__state[state] = states[state]
- self.__default_state[state] = states[state]
+ if state in self.__states:
+ raise IndexError("The state '%s' is already in the StateMachine." % state)
+ self.__states.append( state )
- def addGroups(self, groups):
- with self.lock:
- for gstate in groups:
- if gstate in self.__state or gstate in self.__group:
- raise IndexError("The key or group '%s' is already in the StateMachine." % gstate)
- for state in groups[gstate]:
- if state in self.__state:
- raise IndexError("The group %s contains a key %s which is not set in the StateMachine." % (gstate, state))
- self.__group[gstate] = groups[gstate]
-
- def set(self, state, status):
+
+ def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
+ '''
+ Transition from the given `from_state` to the given `to_state`.
+ This method will return `True` if the state machine is now in `to_state`. It
+ will return `False` if a timeout occurred the transition did not occur.
+ If `wait` is 0 (the default,) this method returns immediately if the state machine
+ is not in `from_state`.
+
+ If you want the thread to block and transition once the state machine to enters
+ `from_state`, set `wait` to a non-negative value. Note there is no 'block
+ indefinitely' flag since this leads to deadlock. If you want to wait indefinitely,
+ choose a reasonable value for `wait` (e.g. 20 seconds) and do so in a while loop like so:
+
+ ::
+
+ while not thread_should_exit and not state_machine.transition('disconnected', 'connecting', wait=20 ):
+ pass # timeout will occur every 20s unless transition occurs
+ if thread_should_exit: return
+ # perform actions here after successful transition
+
+ This allows the thread to be responsive by setting `thread_should_exit=True`.
+
+ The optional `func` argument allows the user to pass a callable operation which occurs
+ within the context of the state transition (e.g. while the state machine is locked.)
+ If `func` returns a True value, the transition will occur. If `func` returns a non-
+ True value or if an exception is thrown, the transition will not occur. Any thrown
+ exception is not caught by the state machine and is the caller's responsibility to handle.
+ If `func` completes normally, this method will return the value returned by `func.` If
+ values for `args` and `kwargs` are provided, they are expanded and passed like so:
+ `func( *args, **kwargs )`.
+ '''
+
+ return self.transition_any( (from_state,), to_state, wait=wait,
+ func=func, args=args, kwargs=kwargs )
+
+
+ def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ):
+ '''
+ Transition from any of the given `from_states` to the given `to_state`.
+ '''
+
+ if not (isinstance(from_states,tuple) or isinstance(from_states,list)):
+ raise ValueError( "from_states should be a list or tuple" )
+
+ for state in from_states:
+ if not state in self.__states:
+ raise ValueError( "StateMachine does not contain from_state %s." % state )
+ if not to_state in self.__states:
+ raise ValueError( "StateMachine does not contain to_state %s." % to_state )
+
with self.lock:
- if state in self.__state:
- self.__state[state] = bool(status)
+ start = time.time()
+ while not self.__current_state in from_states:
+ # detect timeout:
+ if time.time() >= start + wait: return False
+ self.lock.wait(wait)
+
+ if self.__current_state in from_states: # should always be True due to lock
+
+ return_val = True
+ # Note that func might throw an exception, but that's OK, it aborts the transition
+ if func is not None: return_val = func(*args,**kwargs)
+
+ # some 'false' value returned from func,
+ # indicating that transition should not occur:
+ if not return_val: return return_val
+
+ logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
+ self.__current_state = to_state
+ self.lock.notify_all()
+ return return_val # some 'true' value returned by func or True if func was None
else:
- raise KeyError("StateMachine does not contain state %s." % state)
-
- def __getitem__(self, key):
- if key in self.__group:
- for state in self.__group[key]:
- if not self.__state[state]:
- return False
- return True
- return self.__state[key]
+ logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
+ return False
+
+
+ def transition_ctx(self, from_state, to_state, wait=0.0):
+ '''
+ Use the state machine as a context manager. The transition occurs on /exit/ from
+ the `with` context, so long as no exception is thrown. For example:
+
+ ::
+
+ with state_machine.transition_ctx('one','two', wait=5) as locked:
+ if locked:
+ # the state machine is currently locked in state 'one', and will
+ # transition to 'two' when the 'with' statement ends, so long as
+ # no exception is thrown.
+ print 'Currently locked in state one: %s' % state_machine['one']
+
+ else:
+ # The 'wait' timed out, and no lock has been acquired
+ print 'Timed out before entering state "one"'
+
+ print 'Since no exception was thrown, we are now in state "two": %s' % state_machine['two']
+
+
+ The other main difference between this method and `transition()` is that the
+ state machine is locked for the duration of the `with` statement. Normally,
+ after a `transition()` occurs, the state machine is immediately unlocked and
+ available to another thread to call `transition()` again.
+ '''
+
+ if not from_state in self.__states:
+ raise ValueError( "StateMachine does not contain from_state %s." % from_state )
+ if not to_state in self.__states:
+ raise ValueError( "StateMachine does not contain to_state %s." % to_state )
+
+ return _StateCtx(self, from_state, to_state, wait)
+
- def __getattr__(self, attr):
- return self.__getitem__(attr)
+ def ensure(self, state, wait=0.0):
+ '''
+ Ensure the state machine is currently in `state`, or wait until it enters `state`.
+ '''
+ return self.ensure_any( (state,), wait=wait )
+
+
+ def ensure_any(self, states, wait=0.0):
+ '''
+ Ensure we are currently in one of the given `states`
+ '''
+ if not (isinstance(states,tuple) or isinstance(states,list)):
+ raise ValueError('states arg should be a tuple or list')
+
+ for state in states:
+ if not state in self.__states:
+ raise ValueError( "StateMachine does not contain state '%s'" % state )
+
+ with self.lock:
+ start = time.time()
+ while not self.__current_state in states:
+ # detect timeout:
+ if time.time() >= start + wait: return False
+ self.lock.wait(wait)
+ return self.__current_state in states # should always be True due to lock
+
def reset(self):
- self.__state = self.__default_state
+ # TODO need to lock before calling this?
+ self.transition(self.__current_state, self._default_state)
+
+
+ def _set_state(self, state): #unsynchronized, only call internally after lock is acquired
+ self.__current_state = state
+ return state
+
+
+ def current_state(self):
+ '''
+ Return the current state name.
+ '''
+ return self.__current_state
+
+
+ def __getitem__(self, state):
+ '''
+ Non-blocking, non-synchronized test to determine if we are in the given state.
+ Use `StateMachine.ensure(state)` to wait until the machine enters a certain state.
+ '''
+ return self.__current_state == state
+
+ def __str__(self):
+ return "".join(( "StateMachine(", ','.join(self.__states), "): ", self.__current_state ))
+
+
+
+class _StateCtx:
+
+ def __init__( self, state_machine, from_state, to_state, wait ):
+ self.state_machine = state_machine
+ self.from_state = from_state
+ self.to_state = to_state
+ self.wait = wait
+ self._timeout = False
+
+ def __enter__(self):
+ self.state_machine.lock.acquire()
+ start = time.time()
+ while not self.state_machine[ self.from_state ]:
+ # detect timeout:
+ if time.time() >= start + self.wait:
+ logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
+ self._timeout = True # to indicate we should not transition
+ return False
+ self.state_machine.lock.wait(self.wait)
+
+ logging.debug('StateMachine entered context in state: %s',
+ self.state_machine.current_state() )
+ return True
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_val is not None:
+ logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
+ self.state_machine.current_state(), exc_type.__name__, exc_val )
+ elif not self._timeout:
+ logging.debug(' ==== TRANSITION %s -> %s',
+ self.state_machine.current_state(), self.to_state)
+ self.state_machine._set_state( self.to_state )
+
+ self.state_machine.lock.notify_all()
+ self.state_machine.lock.release()
+ return False # re-raise any exception
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index 025884b7..a8bcac00 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -1,9 +1,9 @@
"""
- SleekXMPP: The Sleek XMPP Library
- Copyright (C) 2010 Nathanael C. Fritz
- This file is part of SleekXMPP.
+ SleekXMPP: The Sleek XMPP Library
+ Copyright (C) 2010 Nathanael C. Fritz
+ This file is part of SleekXMPP.
- See the file license.txt for copying permission.
+ See the file license.txt for copying permission.
"""
from __future__ import with_statement, unicode_literals
@@ -16,12 +16,14 @@ from . stanzabase import StanzaBase
from xml.etree import cElementTree
from xml.parsers import expat
import logging
+import random
import socket
import threading
import time
import traceback
import types
import xml.sax.saxutils
+from . import scheduler
HANDLER_THREADS = 1
@@ -45,6 +47,10 @@ class CloseStream(Exception):
stanza_extensions = {}
+RECONNECT_MAX_DELAY = 3600
+RECONNECT_QUIESCE_FACTOR = 1.6180339887498948 # Phi
+RECONNECT_QUIESCE_JITTER = 0.11962656472 # molar Planck constant times c, joule meter/mole
+
class XMLStream(object):
"A connection manager with XML events."
@@ -52,8 +58,9 @@ class XMLStream(object):
global ssl_support
self.ssl_support = ssl_support
self.escape_quotes = escape_quotes
- self.state = statemachine.StateMachine()
- self.state.addStates({'connected':False, 'is client':False, 'ssl':False, 'tls':False, 'reconnect':True, 'processing':False, 'disconnecting':False}) #set initial states
+ self.state = statemachine.StateMachine(('disconnected','connecting',
+ 'connected'))
+ self.should_reconnect = True
self.setSocket(socket)
self.address = (host, int(port))
@@ -69,12 +76,14 @@ class XMLStream(object):
self.filesocket = None
self.use_ssl = False
self.use_tls = False
+ self.ca_certs=None
self.stream_header = "<stream>"
self.stream_footer = "</stream>"
self.eventqueue = queue.Queue()
- self.sendqueue = queue.Queue()
+ self.sendqueue = queue.PriorityQueue()
+ self.scheduler = scheduler.Scheduler(self.eventqueue)
self.namespace_map = {}
@@ -83,45 +92,77 @@ class XMLStream(object):
def setSocket(self, socket):
"Set the socket"
self.socket = socket
- if socket is not None:
+ if socket is not None and self.state.transition('disconnected','connecting'):
self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary
- self.state.set('connected', True)
-
+ self.state.transition('connecting','connected')
def setFileSocket(self, filesocket):
self.filesocket = filesocket
- def connect(self, host='', port=0, use_ssl=False, use_tls=True):
- "Link to connectTCP"
- return self.connectTCP(host, port, use_ssl, use_tls)
+ def connect(self, host='', port=0, use_ssl=None, use_tls=None):
+ "Establish a socket connection to the given XMPP server."
+
+ if not self.state.transition('disconnected','connected',
+ func=self.connectTCP, args=[host, port, use_ssl, use_tls] ):
+
+ if self.state['connected']: logging.debug('Already connected')
+ else: logging.warning("Connection failed" )
+ return False
+
+ logging.debug('Connection complete.')
+ return True
+
+ # TODO currently a caller can't distinguish between "connection failed" and
+ # "we're already trying to connect from another thread"
def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True):
"Connect and create socket"
- while reattempt and not self.state['connected']:
- if host and port:
- self.address = (host, int(port))
- if use_ssl is not None:
- self.use_ssl = use_ssl
- if use_tls is not None:
- self.use_tls = use_tls
- self.state.set('is client', True)
- if sys.version_info < (3, 0):
- self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM)
- else:
- self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.socket.settimeout(None)
- if self.use_ssl and self.ssl_support:
- logging.debug("Socket Wrapped for SSL")
- self.socket = ssl.wrap_socket(self.socket)
+
+ # Note that this is thread-safe by merit of being called solely from connect() which
+ # holds the state lock.
+
+ delay = 1.0 # reconnection delay
+ while self.run:
+ logging.debug('connecting....')
try:
+ if host and port:
+ self.address = (host, int(port))
+ if use_ssl is not None:
+ self.use_ssl = use_ssl
+ if use_tls is not None:
+ # TODO this variable doesn't seem to be used for anything!
+ self.use_tls = use_tls
+ if sys.version_info < (3, 0):
+ self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM)
+ else:
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.socket.settimeout(None) #10)
+
+ if self.use_ssl and self.ssl_support:
+ logging.debug("Socket Wrapped for SSL")
+ self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs)
+
self.socket.connect(self.address)
- #self.filesocket = self.socket.makefile('rb', 0)
self.filesocket = self.socket.makefile('rb', 0)
- self.state.set('connected', True)
+
return True
+
except socket.error as serr:
- logging.error("Could not connect. Socket Error #%s: %s" % (serr.errno, serr.strerror))
- time.sleep(1)
+ logging.exception("Socket Error #%s: %s", serr.errno, serr.strerror)
+ if not reattempt: return False
+ except:
+ logging.exception("Connection error")
+ if not reattempt: return False
+
+ # quiesce if rconnection fails:
+ # This algorithm based loosely on Twisted internet.protocol
+ # http://twistedmatrix.com/trac/browser/trunk/twisted/internet/protocol.py#L310
+ delay = min(delay * RECONNECT_QUIESCE_FACTOR, RECONNECT_MAX_DELAY)
+ delay = random.normalvariate(delay, delay * RECONNECT_QUIESCE_JITTER)
+ logging.debug('Waiting %fs until next reconnect attempt...', delay)
+ time.sleep(delay)
+
+
def connectUnix(self, filepath):
"Connect to Unix file and create socket"
@@ -130,14 +171,19 @@ class XMLStream(object):
"Handshakes for TLS"
if self.ssl_support:
logging.info("Negotiating TLS")
- self.realsocket = self.socket
- self.socket = ssl.wrap_socket(self.socket, ssl_version=ssl.PROTOCOL_TLSv1, do_handshake_on_connect=False)
+# self.realsocket = self.socket # NOT USED
+ self.socket = ssl.wrap_socket(self.socket,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ do_handshake_on_connect=False,
+ ca_certs=self.ca_certs)
self.socket.do_handshake()
if sys.version_info < (3,0):
from . filesocket import filesocket
self.filesocket = filesocket(self.socket)
else:
self.filesocket = self.socket.makefile('rb', 0)
+
+ logging.debug("TLS negotitation successful")
return True
else:
logging.warning("Tried to enable TLS, but ssl module not found.")
@@ -145,67 +191,56 @@ class XMLStream(object):
raise RestartStream()
def process(self, threaded=True):
+ self.scheduler.process(threaded=True)
+ self.run = True
for t in range(0, HANDLER_THREADS):
- self.__thread['eventhandle%s' % t] = threading.Thread(name='eventhandle%s' % t, target=self._eventRunner)
- self.__thread['eventhandle%s' % t].start()
- self.__thread['sendthread'] = threading.Thread(name='sendthread', target=self._sendThread)
- self.__thread['sendthread'].start()
+ th = threading.Thread(name='eventhandle%s' % t, target=self._eventRunner)
+ th.setDaemon(True)
+ self.__thread['eventhandle%s' % t] = th
+ th.start()
+ th = threading.Thread(name='sendthread', target=self._sendThread)
+ th.setDaemon(True)
+ self.__thread['sendthread'] = th
+ th.start()
if threaded:
- self.__thread['process'] = threading.Thread(name='process', target=self._process)
- self.__thread['process'].start()
+ th = threading.Thread(name='process', target=self._process)
+ th.setDaemon(True)
+ self.__thread['process'] = th
+ th.start()
else:
self._process()
- def schedule(self, seconds, handler, args=None):
- threading.Timer(seconds, handler, args).start()
+ def schedule(self, name, seconds, callback, args=None, kwargs=None, repeat=False):
+ self.scheduler.add(name, seconds, callback, args, kwargs, repeat, qpointer=self.eventqueue)
def _process(self):
"Start processing the socket."
- firstrun = True
- while self.run and (firstrun or self.state['reconnect']):
- self.state.set('processing', True)
- firstrun = False
+ logging.debug('Process thread starting...')
+ while self.run:
+ if not self.state.ensure('connected',wait=2): continue
try:
- if self.state['is client']:
- self.sendRaw(self.stream_header)
- while self.run and self.__readXML():
- if self.state['is client']:
- self.sendRaw(self.stream_header)
- except KeyboardInterrupt:
- logging.debug("Keyboard Escape Detected")
- self.state.set('processing', False)
- self.state.set('reconnect', False)
- self.disconnect()
- self.run = False
- self.eventqueue.put(('quit', None, None))
- return
+ self.sendPriorityRaw(self.stream_header)
+ while self.run and self.__readXML(): pass
+ except socket.timeout:
+ logging.debug('socket rcv timeout')
+ pass
except CloseStream:
- return
- except SystemExit:
+ # TODO warn that the listener thread is exiting!!!
+ pass
+ except RestartStream:
+ logging.debug("Restarting stream...")
+ continue # DON'T re-initialize the stream -- this exception is sent
+ # specifically when we've initialized TLS and need to re-send the <stream> header.
+ except (KeyboardInterrupt, SystemExit):
+ logging.debug("System interrupt detected")
+ self.shutdown()
self.eventqueue.put(('quit', None, None))
- return
- except socket.error:
- if not self.state.reconnect:
- return
- else:
- self.state.set('processing', False)
- traceback.print_exc()
- self.disconnect(reconnect=True)
except:
- if not self.state.reconnect:
- return
- else:
- self.state.set('processing', False)
- traceback.print_exc()
+ logging.exception('Unexpected error in RCV thread')
+ if self.should_reconnect:
self.disconnect(reconnect=True)
- if self.state['reconnect']:
- self.reconnect()
- self.state.set('processing', False)
- self.eventqueue.put(('quit', None, None))
- #self.__thread['readXML'] = threading.Thread(name='readXML', target=self.__readXML)
- #self.__thread['readXML'].start()
- #self.__thread['spawnEvents'] = threading.Thread(name='spawnEvents', target=self.__spawnEvents)
- #self.__thread['spawnEvents'].start()
+
+ logging.debug('Quitting Process thread')
def __readXML(self):
"Parses the incoming stream, adding to xmlin queue as it goes"
@@ -218,82 +253,94 @@ class XMLStream(object):
if edepth == 0: # and xmlobj.tag.split('}', 1)[-1] == self.basetag:
if event == b'start':
root = xmlobj
+ logging.debug('handling start stream')
self.start_stream_handler(root)
if event == b'end':
edepth += -1
if edepth == 0 and event == b'end':
- self.disconnect(reconnect=self.state['reconnect'])
+ # what is this case exactly? Premature EOF?
+ logging.debug("Ending readXML loop")
return False
elif edepth == 1:
#self.xmlin.put(xmlobj)
- try:
- self.__spawnEvent(xmlobj)
- except RestartStream:
- return True
- except CloseStream:
- return False
- if root:
- root.clear()
+ self.__spawnEvent(xmlobj)
+ if root: root.clear()
if event == b'start':
edepth += 1
+ logging.debug("Exiting readXML loop")
+ return False
def _sendThread(self):
+ logging.debug('send thread starting...')
while self.run:
- data = self.sendqueue.get(True)
- logging.debug("SEND: %s" % data)
+ if not self.state.ensure('connected',wait=2): continue
+
+ data = None
try:
- self.socket.send(data.encode('utf-8'))
- #self.socket.send(bytes(data, "utf-8"))
- #except socket.error,(errno, strerror):
+ data = self.sendqueue.get(True,5)[1]
+ logging.debug("SEND: %s" % data)
+ self.socket.sendall(data.encode('utf-8'))
+ except queue.Empty:
+# logging.debug('Nothing on send queue')
+ pass
+ except socket.timeout:
+ # this is to prevent a thread blocked indefinitely
+ logging.debug('timeout sending packet data')
except:
logging.warning("Failed to send %s" % data)
- self.state.set('connected', False)
- if self.state.reconnect:
- logging.error("Disconnected. Socket Error.")
- traceback.print_exc()
+ logging.exception("Socket error in SEND thread")
+ # TODO it's somewhat unsafe for the sender thread to assume it can just
+ # re-intitialize the connection, since the receiver thread could be doing
+ # the same thing concurrently. Oops! The safer option would be to throw
+ # some sort of event that could be handled by a common thread or the reader
+ # thread to perform reconnect and then re-initialize the handler threads as well.
+ if self.should_reconnect:
self.disconnect(reconnect=True)
def sendRaw(self, data):
- self.sendqueue.put(data)
+ self.sendqueue.put((1, data))
+ return True
+
+ def sendPriorityRaw(self, data):
+ self.sendqueue.put((0, data))
return True
def disconnect(self, reconnect=False):
- self.state.set('reconnect', reconnect)
- if self.state['disconnecting']:
+ if not self.state.transition('connected','disconnected'):
+ logging.warning("Already disconnected.")
return
- if not self.state['reconnect']:
- logging.debug("Disconnecting...")
- self.state.set('disconnecting', True)
- self.run = False
- if self.state['connected']:
- self.sendRaw(self.stream_footer)
- time.sleep(1)
- #send end of stream
- #wait for end of stream back
+ logging.debug("Disconnecting...")
+ self.sendPriorityRaw(self.stream_footer)
+ time.sleep(5)
+ #send end of stream
+ #wait for end of stream back
try:
+# self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
+ except socket.error as (errno,strerror):
+ logging.exception("Error while disconnecting. Socket Error #%s: %s" % (errno, strerror))
+ try:
self.filesocket.close()
- self.socket.shutdown(socket.SHUT_RDWR)
- except socket.error as serr:
- #logging.warning("Error while disconnecting. Socket Error #%s: %s" % (errno, strerror))
- #thread.exit_thread()
- pass
- if self.state['processing']:
- #raise CloseStream
- pass
-
- def reconnect(self):
- self.state.set('tls',False)
- self.state.set('ssl',False)
- time.sleep(1)
- self.connect()
+ except socket.error as (errno,strerror):
+ logging.exception("Error closing filesocket.")
+
+ if reconnect: self.connect()
+ def shutdown(self):
+ '''
+ Disconnects and shuts down all event threads.
+ '''
+ self.disconnect()
+ self.run = False
+ self.scheduler.run = False
+
def incoming_filter(self, xmlobj):
return xmlobj
-
+
def __spawnEvent(self, xmlobj):
"watching xmlOut and processes handlers"
#convert XML into Stanza
+ # TODO surround this log statement with an if, it's expensive
logging.debug("RECV: %s" % cElementTree.tostring(xmlobj))
xmlobj = self.incoming_filter(xmlobj)
stanza = None
@@ -305,48 +352,54 @@ class XMLStream(object):
if stanza is None:
stanza = StanzaBase(self, xmlobj)
unhandled = True
+ # TODO inefficient linear search; performance might be improved by hashtable lookup
for handler in self.__handlers:
if handler.match(stanza):
+ logging.debug('matched stanza to handler %s', handler.name)
handler.prerun(stanza)
self.eventqueue.put(('stanza', handler, stanza))
- if handler.checkDelete(): self.__handlers.pop(self.__handlers.index(handler))
+ if handler.checkDelete():
+ logging.debug('deleting callback %s', handler.name)
+ self.__handlers.pop(self.__handlers.index(handler))
unhandled = False
if unhandled:
stanza.unhandled()
#loop through handlers and test match
#spawn threads as necessary, call handlers, sending Stanza
-
+
def _eventRunner(self):
logging.debug("Loading event runner")
while self.run:
try:
event = self.eventqueue.get(True, timeout=5)
except queue.Empty:
+# logging.debug('Nothing on event queue')
event = None
if event is not None:
etype = event[0]
handler = event[1]
args = event[2:]
- #etype, handler, *args = event #python 3.x way
+ #etype, handler, *args = event #python 3.x way
if etype == 'stanza':
try:
handler.run(args[0])
except Exception as e:
- traceback.print_exc()
+ logging.exception("Exception in event handler")
args[0].exception(e)
elif etype == 'sched':
try:
+ #handler(*args[0])
handler.run(*args)
except:
logging.error(traceback.format_exc())
elif etype == 'quit':
logging.debug("Quitting eventRunner thread")
return False
-
+
def registerHandler(self, handler, before=None, after=None):
"Add handler with matcher class and parameters."
self.__handlers.append(handler)
-
+
def removeHandler(self, name):
"Removes the handler."
idx = 0
@@ -432,4 +485,4 @@ class XMLStream(object):
def start_stream_handler(self, xml):
"""Meant to be overridden"""
- pass
+ logging.warn("No start stream handler has been implemented.")
diff --git a/tests/test_disco.py b/tests/test_disco.py
new file mode 100644
index 00000000..bbe285a6
--- /dev/null
+++ b/tests/test_disco.py
@@ -0,0 +1,155 @@
+import unittest
+from xml.etree import cElementTree as ET
+from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath
+from . import xmlcompare
+
+import sleekxmpp.plugins.xep_0030 as sd
+
+def stanzaPlugin(stanza, plugin):
+ stanza.plugin_attrib_map[plugin.plugin_attrib] = plugin
+ stanza.plugin_tag_map["{%s}%s" % (plugin.namespace, plugin.name)] = plugin
+
+class testdisco(unittest.TestCase):
+
+ def setUp(self):
+ self.sd = sd
+ stanzaPlugin(self.sd.Iq, self.sd.DiscoInfo)
+ stanzaPlugin(self.sd.Iq, self.sd.DiscoItems)
+
+ def try3Methods(self, xmlstring, iq):
+ iq2 = self.sd.Iq(None, self.sd.ET.fromstring(xmlstring))
+ values = iq2.getValues()
+ iq3 = self.sd.Iq()
+ iq3.setValues(values)
+ self.failUnless(xmlstring == str(iq) == str(iq2) == str(iq3), str(iq)+"3 methods for creating stanza don't match")
+
+ def testCreateInfoQueryNoNode(self):
+ """Testing disco#info query with no node."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_info']['node'] = ''
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#info" /></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testCreateInfoQueryWithNode(self):
+ """Testing disco#info query with a node."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_info']['node'] = 'foo'
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#info" node="foo" /></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testCreateInfoQueryNoNode(self):
+ """Testing disco#items query with no node."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_items']['node'] = ''
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#items" /></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testCreateItemsQueryWithNode(self):
+ """Testing disco#items query with a node."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_items']['node'] = 'foo'
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#items" node="foo" /></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testInfoIdentities(self):
+ """Testing adding identities to disco#info."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_info']['node'] = 'foo'
+ iq['disco_info'].addIdentity('conference', 'text', 'Chatroom')
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#info" node="foo"><identity category="conference" type="text" name="Chatroom" /></query></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testInfoFeatures(self):
+ """Testing adding features to disco#info."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_info']['node'] = 'foo'
+ iq['disco_info'].addFeature('foo')
+ iq['disco_info'].addFeature('bar')
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#info" node="foo"><feature var="foo" /><feature var="bar" /></query></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testItems(self):
+ """Testing adding features to disco#info."""
+ iq = self.sd.Iq()
+ iq['id'] = "0"
+ iq['disco_items']['node'] = 'foo'
+ iq['disco_items'].addItem('user@localhost')
+ iq['disco_items'].addItem('user@localhost', 'foo')
+ iq['disco_items'].addItem('user@localhost', 'bar', 'Testing')
+ xmlstring = """<iq id="0"><query xmlns="http://jabber.org/protocol/disco#items" node="foo"><item jid="user@localhost" /><item node="foo" jid="user@localhost" /><item node="bar" jid="user@localhost" name="Testing" /></query></iq>"""
+ self.try3Methods(xmlstring, iq)
+
+ def testAddRemoveIdentities(self):
+ """Test adding and removing identities to disco#info stanza"""
+ ids = [('automation', 'commands', 'AdHoc'),
+ ('conference', 'text', 'ChatRoom')]
+
+ info = self.sd.DiscoInfo()
+ info.addIdentity(*ids[0])
+ self.failUnless(info.getIdentities() == [ids[0]])
+
+ info.delIdentity('automation', 'commands')
+ self.failUnless(info.getIdentities() == [])
+
+ info.setIdentities(ids)
+ self.failUnless(info.getIdentities() == ids)
+
+ info.delIdentity('automation', 'commands')
+ self.failUnless(info.getIdentities() == [ids[1]])
+
+ info.delIdentities()
+ self.failUnless(info.getIdentities() == [])
+
+ def testAddRemoveFeatures(self):
+ """Test adding and removing features to disco#info stanza"""
+ features = ['foo', 'bar', 'baz']
+
+ info = self.sd.DiscoInfo()
+ info.addFeature(features[0])
+ self.failUnless(info.getFeatures() == [features[0]])
+
+ info.delFeature('foo')
+ self.failUnless(info.getFeatures() == [])
+
+ info.setFeatures(features)
+ self.failUnless(info.getFeatures() == features)
+
+ info.delFeature('bar')
+ self.failUnless(info.getFeatures() == ['foo', 'baz'])
+
+ info.delFeatures()
+ self.failUnless(info.getFeatures() == [])
+
+ def testAddRemoveItems(self):
+ """Test adding and removing items to disco#items stanza"""
+ items = [('user@localhost', None, None),
+ ('user@localhost', 'foo', None),
+ ('user@localhost', 'bar', 'Test')]
+
+ info = self.sd.DiscoItems()
+ self.failUnless(True, ""+str(items[0]))
+
+ info.addItem(*(items[0]))
+ self.failUnless(info.getItems() == [items[0]], info.getItems())
+
+ info.delItem('user@localhost')
+ self.failUnless(info.getItems() == [])
+
+ info.setItems(items)
+ self.failUnless(info.getItems() == items)
+
+ info.delItem('user@localhost', 'foo')
+ self.failUnless(info.getItems() == [items[0], items[2]])
+
+ info.delItems()
+ self.failUnless(info.getItems() == [])
+
+
+
+suite = unittest.TestLoader().loadTestsFromTestCase(testdisco)
diff --git a/tests/test_pubsubstanzas.py b/tests/test_pubsubstanzas.py
index 55407c16..089ee180 100644
--- a/tests/test_pubsubstanzas.py
+++ b/tests/test_pubsubstanzas.py
@@ -97,6 +97,21 @@ class testpubsubstanzas(unittest.TestCase):
iq3.setValues(values)
self.failUnless(xmlstring == str(iq) == str(iq2) == str(iq3))
+ def testState(self):
+ "Testing iq/psstate stanzas"
+ from sleekxmpp.plugins import xep_0004
+ iq = self.ps.Iq()
+ iq['psstate']['node']= 'mynode'
+ iq['psstate']['item']= 'myitem'
+ pl = ET.Element('{http://andyet.net/protocol/pubsubqueue}claimed')
+ iq['psstate']['payload'] = pl
+ xmlstring = """<iq id="0"><state xmlns="http://jabber.org/protocol/psstate" node="mynode" item="myitem"><claimed xmlns="http://andyet.net/protocol/pubsubqueue" /></state></iq>"""
+ iq2 = self.ps.Iq(None, self.ps.ET.fromstring(xmlstring))
+ iq3 = self.ps.Iq()
+ values = iq2.getValues()
+ iq3.setValues(values)
+ self.failUnless(xmlstring == str(iq) == str(iq2) == str(iq3))
+
def testDefault(self):
"Testing iq/pubsub_owner/default stanzas"
from sleekxmpp.plugins import xep_0004
diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py
new file mode 100644
index 00000000..e44b8e48
--- /dev/null
+++ b/tests/test_statemachine.py
@@ -0,0 +1,261 @@
+import unittest
+import time, threading, random, functools
+
+if __name__ == '__main__':
+ import sys, os
+ sys.path.insert(0, os.getcwd())
+ import sleekxmpp.xmlstream.statemachine as sm
+
+
+class testStateMachine(unittest.TestCase):
+
+ def setUp(self): pass
+
+
+ def testDefaults(self):
+ "Test ensure transitions occur correctly in a single thread"
+ s = sm.StateMachine(('one','two','three'))
+ self.assertTrue(s['one'])
+ self.failIf(s['two'])
+ try:
+ s['booga']
+ self.fail('s.booga is an invalid state and should throw an exception!')
+ except: pass #expected exception
+
+ # just make sure __str__ works, no reason to test its exact value:
+ print str(s)
+
+
+ def testTransitions(self):
+ "Test ensure transitions occur correctly in a single thread"
+ s = sm.StateMachine(('one','two','three'))
+
+ self.assertTrue( s.transition('one', 'two') )
+ self.assertTrue( s['two'] )
+ self.failIf( s['one'] )
+
+ self.assertTrue( s.transition('two', 'three') )
+ self.assertTrue( s['three'] )
+ self.failIf( s['two'] )
+
+ self.assertTrue( s.transition('three', 'one') )
+ self.assertTrue( s['one'] )
+ self.failIf( s['three'] )
+
+ # should return False immediately w/ no wait:
+ self.failIf( s.transition('three', 'one') )
+ self.assertTrue( s['one'] )
+ self.failIf( s['three'] )
+
+ # test fail condition w/ a short delay:
+ self.failIf( s.transition('two', 'three') )
+
+ # Ensure bad states are weeded out:
+ try:
+ s.transition('blah', 'three')
+ s.fail('Exception expected')
+ except: pass
+
+ try:
+ s.transition('one', 'blahblah')
+ s.fail('Exception expected')
+ except: pass
+
+
+ def testTransitionsBlocking(self):
+ "Test that transitions block from more than one thread"
+
+ s = sm.StateMachine(('one','two','three'))
+ self.assertTrue(s['one'])
+
+ now = time.time()
+ self.failIf( s.transition('two', 'one', wait=5.0) )
+ self.assertTrue( time.time() > now + 4 )
+ self.assertTrue( time.time() < now + 7 )
+
+ def testThreadedTransitions(self):
+ "Test that transitions are atomic in > one thread"
+
+ s = sm.StateMachine(('one','two','three'))
+ self.assertTrue(s['one'])
+
+ thread_state = {'ready': False, 'transitioned': False}
+ def t1():
+ if s['two']:
+ print 'thread has already transitioned!'
+ self.fail()
+ thread_state['ready'] = True
+ print 'Thread is ready'
+ # this will block until the main thread transitions to 'two'
+ self.assertTrue( s.transition('two','three', wait=20) )
+ print 'transitioned to three!'
+ thread_state['transitioned'] = True
+
+ thread = threading.Thread(target=t1)
+ thread.daemon = True
+ thread.start()
+ start = time.time()
+ while not thread_state['ready']:
+ print 'not ready'
+ if time.time() > start+10: self.fail('Timeout waiting for thread to init!')
+ time.sleep(0.1)
+ time.sleep(0.2) # the thread should be blocking on the 'transition' call at this point.
+ self.failIf( thread_state['transitioned'] ) # ensure it didn't 'go' yet.
+ print 'transitioning to two!'
+ self.assertTrue( s.transition('one','two') )
+ time.sleep(0.2) # second thread should have transitioned now:
+ self.assertTrue( thread_state['transitioned'] )
+
+
+ def testForRaceCondition(self):
+ """Attempt to allow two threads to perform the same transition;
+ only one should ever make it."""
+
+ s = sm.StateMachine(('one','two','three'))
+
+ def t1(num):
+ while True:
+ if not trigger['go'] or thread_state[num] in (True,False):
+ time.sleep( random.random()/100 ) # < .01s
+ if thread_state[num] == 'quit': break
+ continue
+
+ thread_state[num] = s.transition('one','two' )
+# print '-',
+
+ thread_count = 20
+ threads = []
+ thread_state = {}
+ def reset():
+ for c in range(thread_count): thread_state[c] = "reset"
+ trigger = {'go':False} # use of a plain boolean seems to be non-volatile between threads.
+
+ for c in range(thread_count):
+ thread_state[c] = "reset"
+ thread = threading.Thread( target= functools.partial(t1,c) )
+ threads.append( thread )
+ thread.daemon = True
+ thread.start()
+
+ for x in range(100): # this will take 10s to execute
+# print "+",
+ trigger['go'] = True
+ time.sleep(.1)
+ trigger['go'] = False
+ winners = 0
+ for (num, state) in thread_state.items():
+ if state == True: winners = winners +1
+ elif state != False: raise Exception( "!%d!%s!" % (num,state) )
+
+ self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
+ self.assertTrue( s.ensure('two') )
+ self.assertTrue( s.transition('two','one') ) # return to the first state.
+ reset()
+
+ # now let the threads quit gracefully:
+ for c in range(thread_count): thread_state[c] = 'quit'
+ time.sleep(2)
+
+
+ def testTransitionFunctions(self):
+ "test that a `func` argument allows or blocks the transition correctly."
+
+ s = sm.StateMachine(('one','two','three'))
+
+ def alwaysFalse(): return False
+ def alwaysTrue(): return True
+
+ self.failIf( s.transition('one','two', func=alwaysFalse) )
+ self.assertTrue(s['one'])
+ self.failIf(s['two'])
+
+ self.assertTrue( s.transition('one','two', func=alwaysTrue) )
+ self.failIf(s['one'])
+ self.assertTrue(s['two'])
+
+
+ def testTransitionFuncException(self):
+ "if a transition function throws an exeption, ensure we're in a sane state"
+
+ s = sm.StateMachine(('one','two','three'))
+
+ def alwaysException(): raise Exception('whups!')
+
+ try:
+ self.failIf( s.transition('one','two', func=alwaysException) )
+ self.fail("exception should have been thrown")
+ except: pass #expected exception
+
+ self.assertTrue(s['one'])
+ self.failIf(s['two'])
+
+ # ensure a subsequent attempt completes normally:
+ self.assertTrue( s.transition('one','two') )
+ self.failIf(s['one'])
+ self.assertTrue(s['two'])
+
+
+ def testContextManager(self):
+
+ s = sm.StateMachine(('one','two','three'))
+
+ with s.transition_ctx('one','two'):
+ self.assertTrue( s['one'] )
+ self.failIf( s['two'] )
+
+ #successful transition b/c no exception was thrown
+ self.assertTrue( s['two'] )
+ self.failIf( s['one'] )
+
+ # failed transition because exception is thrown:
+ try:
+ with s.transition_ctx('two','three'):
+ raise Exception("boom!")
+ self.fail('exception expected')
+ except: pass
+
+ self.failIf( s.current_state() in ('one','three') )
+ self.assertTrue( s['two'] )
+
+ def testCtxManagerTransitionFailure(self):
+
+ s = sm.StateMachine(('one','two','three'))
+
+ with s.transition_ctx('two','three') as result:
+ self.failIf( result )
+ self.assertTrue( s['one'] )
+ self.failIf( s.current_state in ('two','three') )
+
+ self.assertTrue( s['one'] )
+
+ def r1():
+ print 'thread 1 started'
+ self.assertTrue( s.transition('one','two') )
+ print 'thread 1 transitioned'
+
+ def r2():
+ print 'thread 2 started'
+ self.failIf( s['two'] )
+ with s.transition_ctx('two','three', 10) as result:
+ self.assertTrue( result )
+ self.assertTrue( s['two'] )
+ print 'thread 2 will transition on exit from the context manager...'
+ self.assertTrue( s['three'] )
+ print 'transitioned to %s' % s.current_state()
+
+ t1 = threading.Thread(target=r1)
+ t2 = threading.Thread(target=r2)
+
+ t2.start() # this should block until r1 goes
+ time.sleep(1)
+ t1.start()
+
+ t1.join()
+ t2.join()
+
+ self.assertTrue( s['three'] )
+
+
+suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
+
+if __name__ == '__main__': unittest.main()