diff options
Diffstat (limited to 'sleekxmpp')
-rw-r--r-- | sleekxmpp/__init__.py | 149 | ||||
-rw-r--r-- | sleekxmpp/basexmpp.py | 39 | ||||
-rw-r--r-- | sleekxmpp/plugins/jobs.py | 44 | ||||
-rw-r--r-- | sleekxmpp/plugins/stanza_pubsub.py | 45 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0004.py | 1 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0030.py | 405 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0060.py | 10 | ||||
-rw-r--r-- | sleekxmpp/stanza/error.py | 10 | ||||
-rw-r--r-- | sleekxmpp/stanza/iq.py | 14 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/handler/base.py | 2 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/handler/callback.py | 4 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/scheduler.py | 87 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/stanzabase.py | 22 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/statemachine.py | 245 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 323 |
15 files changed, 1048 insertions, 352 deletions
diff --git a/sleekxmpp/__init__.py b/sleekxmpp/__init__.py index 263f1f99..b62e35db 100644 --- a/sleekxmpp/__init__.py +++ b/sleekxmpp/__init__.py @@ -1,13 +1,12 @@ #!/usr/bin/python2.5 """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ -from __future__ import absolute_import, unicode_literals from . basexmpp import basexmpp from xml.etree import cElementTree as ET from . xmlstream.xmlstream import XMLStream @@ -27,10 +26,15 @@ import sys import random import copy from . import plugins +from xml.etree.cElementTree import tostring +from xml.etree.cElementTree import Element +from cStringIO import StringIO + #from . import stanza srvsupport = True try: import dns.resolver + import dns.rdatatype except ImportError: srvsupport = False @@ -53,12 +57,14 @@ class ClientXMPP(basexmpp, XMLStream): self.plugin_config = plugin_config self.escape_quotes = escape_quotes self.set_jid(jid) + self.server = None + self.port = 5222 # not used if DNS SRV is used self.plugin_whitelist = plugin_whitelist self.auto_reconnect = True self.srvsupport = srvsupport self.password = password self.registered_features = [] - self.stream_header = """<stream:stream to='%s' xmlns:stream='http://etherx.jabber.org/streams' xmlns='%s' version='1.0'>""" % (self.server,self.default_ns) + self.stream_header = """<stream:stream to='%s' xmlns:stream='http://etherx.jabber.org/streams' xmlns='%s' version='1.0'>""" % (self.domain,self.default_ns) self.stream_footer = "</stream:stream>" #self.map_namespace('http://etherx.jabber.org/streams', 'stream') #self.map_namespace('jabber:client', '') @@ -66,8 +72,16 @@ class ClientXMPP(basexmpp, XMLStream): #TODO: Use stream state here self.authenticated = False self.sessionstarted = False - self.registerHandler(Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True)) - self.registerHandler(Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True)) + self.bound = False + self.bindfail = False + self.digest_auth_started = False + XMLStream.registerHandler(self, Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True)) + XMLStream.registerHandler(self, Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True)) + #SASL Auth handlers + basexmpp.add_handler(self, "<challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_sasl_digest_md5_auth, instream=True) + basexmpp.add_handler(self, "<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/>", self.handler_sasl_digest_md5_auth_fail, instream=True) + basexmpp.add_handler(self, "<success xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_success, instream=True) + basexmpp.add_handler(self, "<failure xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_fail, instream=True) #self.registerHandler(Callback('Roster Update', MatchXMLMask("<presence xmlns='%s' type='subscribe' />" % self.default_ns), self._handlePresenceSubscribe, thread=True)) self.registerFeature("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls' />", self.handler_starttls, True) self.registerFeature("<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_sasl_auth, True) @@ -87,12 +101,18 @@ class ClientXMPP(basexmpp, XMLStream): def get(self, key, default): return self.plugin.get(key, default) - def connect(self, address=tuple()): + def connect(self, host=None, port=None): """Connect to the Jabber Server. Attempts SRV lookup, and if it fails, uses the JID server.""" - if not address or len(address) < 2: + + if self.state['connected']: return True + + if host: + self.server = host + if port is None: port = self.port + else: if not self.srvsupport: - logging.debug("Did not supply (address, port) to connect to and no SRV support is installed (http://www.dnspython.org). Continuing to attempt connection, using server hostname from JID.") + logging.debug("Did not supply (address, port) to connect to and no SRV support is installed (http://www.dnspython.org). Continuing to attempt connection, using domain from JID.") else: logging.debug("Since no address is supplied, attempting SRV lookup.") try: @@ -113,12 +133,19 @@ class ClientXMPP(basexmpp, XMLStream): picked = random.randint(0, intmax) for priority in priorities: if picked <= priority: - address = addresses[priority] + (host,port) = addresses[priority] break - if not address: + # if SRV lookup was successful, we aren't using a particular server. + self.server = None + + if not host: # if all else fails take server from JID. - address = (self.server, 5222) - result = XMLStream.connect(self, address[0], address[1], use_tls=True) + (host,port) = (self.domain, self.port) + self.server = None + + logging.debug('Attempting connection to %s:%d', host, port ) + #TODO option to not use TLS? + result = XMLStream.connect(self, host, port, use_tls=True) if result: self.event("connected") else: @@ -129,12 +156,12 @@ class ClientXMPP(basexmpp, XMLStream): # overriding reconnect and disconnect so that we can get some events # should events be part of or required by xmlstream? Maybe that would be cleaner def reconnect(self): - logging.info("Reconnecting") - self.event("disconnected") - XMLStream.reconnect(self) + self.disconnect(reconnect=True) - def disconnect(self, init=True, close=False, reconnect=False): + def disconnect(self, reconnect=False): self.event("disconnected") + self.authenticated = False + self.sessionstarted = False XMLStream.disconnect(self, reconnect) def registerFeature(self, mask, pointer, breaker = False): @@ -155,6 +182,7 @@ class ClientXMPP(basexmpp, XMLStream): self._handleRoster(iq, request=True) def _handleStreamFeatures(self, features): + logging.debug('handling stream features') self.features = [] for sub in features.xml: self.features.append(sub.tag) @@ -162,13 +190,17 @@ class ClientXMPP(basexmpp, XMLStream): for feature in self.registered_features: if feature[0].match(subelement): #if self.maskcmp(subelement, feature[0], True): + # This calls the feature handler & optionally breaks if feature[1](subelement) and feature[2]: #if breaker, don't continue return True def handler_starttls(self, xml): + logging.debug( 'TLS start handler; SSL support: %s', self.ssl_support ) if not self.authenticated and self.ssl_support: - self.add_handler("<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls' />", self.handler_tls_start, instream=True) - self.sendXML(xml) + _stanza = "<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls' />" + if not self.event_handlers.get(_stanza,None): # don't add handler > once + self.add_handler( _stanza, self.handler_tls_start, instream=True ) + self.sendPriorityRaw(self.tostring(xml)) return True else: logging.warning("The module tlslite is required in to some servers, and has not been found.") @@ -183,17 +215,17 @@ class ClientXMPP(basexmpp, XMLStream): if '{urn:ietf:params:xml:ns:xmpp-tls}starttls' in self.features: return False logging.debug("Starting SASL Auth") - self.add_handler("<success xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_success, instream=True) - self.add_handler("<failure xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", self.handler_auth_fail, instream=True) sasl_mechs = xml.findall('{urn:ietf:params:xml:ns:xmpp-sasl}mechanism') if len(sasl_mechs): for sasl_mech in sasl_mechs: self.features.append("sasl:%s" % sasl_mech.text) - if 'sasl:PLAIN' in self.features: + if 'sasl:DIGEST-MD5' in self.features: + self.sendPriorityRaw("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>""") + elif 'sasl:PLAIN' in self.features: if sys.version_info < (3,0): - self.send("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username) + b'\x00' + bytes(self.password)).decode('utf-8')) + self.sendPriorityRaw("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username) + b'\x00' + bytes(self.password)).decode('utf-8')) else: - self.send("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username, 'utf-8') + b'\x00' + bytes(self.password, 'utf-8')).decode('utf-8')) + self.sendPriorityRaw("""<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='PLAIN'>%s</auth>""" % base64.b64encode(b'\x00' + bytes(self.username, 'utf-8') + b'\x00' + bytes(self.password, 'utf-8')).decode('utf-8')) else: logging.error("No appropriate login method.") self.disconnect() @@ -201,13 +233,50 @@ class ClientXMPP(basexmpp, XMLStream): # self._auth_digestmd5() return True + def handler_sasl_digest_md5_auth(self, xml): + logging.debug(tostring(xml)) + logging.debug(xml) + logging.debug(type(xml).__name__) + + if self.digest_auth_started == False: + challenge = [item.split('=', 1) for item in base64.b64decode(xml.text).replace("\"", "").split(',', 6) ] + challenge = dict(challenge) + logging.debug(challenge) + + #Realm, nonce, qop should all be present + if not challenge['realm'] or not challenge['qop'] or not challenge['nonce']: + logging.error("Error during digest-md5 authentication. Challenge missing critical information. Challenge: %s" %base64.b64decode(xml.text)) + self.disconnect() + self.event("failed_auth") + return + #TODO: charset can be either UTF-8 or if not present use ISO 8859-1 defaulting for UTF-8 for now + #Compute the cnonce - a unique hex string only used in this request + cnonce = "" + for i in range(7): + cnonce+=hex(int(random.random()*65536*4096))[2:] + cnonce = base64.encodestring(cnonce)[0:-1] + a1 = md5("%s:%s:%s" % (self.username, self.domain, self.password)) + a1 = "%s:%s:%s" %(a1, challenge["nonce"], cnonce ) + a2 = "AUTHENTICATE:xmpp/%s" %self.domain + responseHash = md5digest("%s:%s:00000001:%s:auth:%s" %(md5digest(a1), challenge["nonce"], cnonce, md5digest(a2) ) ) + response = '''charset=utf-8,username="%s",realm="%s",nonce="%s",nc=00000001,cnonce="%s",digest-uri="%s",response=%s,qop=%s,''' %(self.username, self.domain, challenge["nonce"], cnonce, "xmpp/%s" % self.domain, responseHash, challenge["qop"]) + self.sendPriorityRaw("""<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>%s</response>""" %base64.encodestring(response)[:-1]) + else: + logging.warn("handler_sasl_digest_md5_auth called while digest_auth_started is True (has already begun)") + + def handler_sasl_digest_md5_auth_fail(self, xml): + self.digest_auth_started = False + self.handler_auth_fail(xml) + def handler_auth_success(self, xml): + logging.debug("Authentication successful.") self.authenticated = True self.features = [] raise RestartStream() def handler_auth_fail(self, xml): - logging.info("Authentication failed.") + logging.warning("Authentication failed.") + logging.debug(tostring(xml, 'utf-8')) self.disconnect() self.event("failed_auth") @@ -221,19 +290,23 @@ class ClientXMPP(basexmpp, XMLStream): response = iq.send() #response = self.send(iq, self.Iq(sid=iq['id'])) self.set_jid(response.xml.find('{urn:ietf:params:xml:ns:xmpp-bind}bind/{urn:ietf:params:xml:ns:xmpp-bind}jid').text) + self.bound = True logging.info("Node set to: %s" % self.fulljid) - if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features: + if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features or self.bindfail: logging.debug("Established Session") self.sessionstarted = True self.event("session_start") def handler_start_session(self, xml): - if self.authenticated: + if self.authenticated and self.bound: iq = self.makeIqSet(xml) response = iq.send() logging.debug("Established Session") self.sessionstarted = True self.event("session_start") + else: + #bind probably hasn't happened yet + self.bindfail = True def _handleRoster(self, iq, request=False): if iq['type'] == 'set' or (iq['type'] == 'result' and request): @@ -244,3 +317,21 @@ class ClientXMPP(basexmpp, XMLStream): if iq['type'] == 'set': self.send(self.Iq().setValues({'type': 'result', 'id': iq['id']}).enable('roster')) self.event("roster_update", iq) + +def md5(data): + try: + import hashlib + md5 = hashlib.md5(data) + except ImportError: + import md5 + md5 = md5.new(data) + return md5.digest() + +def md5digest(data): + try: + import hashlib + md5 = hashlib.md5(data) + except ImportError: + import md5 + md5 = md5.new(data) + return md5.hexdigest() diff --git a/sleekxmpp/basexmpp.py b/sleekxmpp/basexmpp.py index 907067fa..936b5d18 100644 --- a/sleekxmpp/basexmpp.py +++ b/sleekxmpp/basexmpp.py @@ -1,9 +1,9 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ from __future__ import with_statement, unicode_literals @@ -49,7 +49,7 @@ class basexmpp(object): self.resource = '' self.jid = '' self.username = '' - self.server = '' + self.domain = '' self.plugin = {} self.auto_authorize = True self.auto_subscribe = True @@ -84,28 +84,35 @@ class basexmpp(object): self.resource = self.getjidresource(jid) self.jid = self.getjidbare(jid) self.username = jid.split('@', 1)[0] - self.server = jid.split('@',1)[-1].split('/', 1)[0] + self.domain = jid.split('@',1)[-1].split('/', 1)[0] def process(self, *args, **kwargs): for idx in self.plugin: if not self.plugin[idx].post_inited: self.plugin[idx].post_init() return super(basexmpp, self).process(*args, **kwargs) - def registerPlugin(self, plugin, pconfig = {}): + def registerPlugin(self, plugin, pconfig = {}, pluginModule = None): """Register a plugin not in plugins.__init__.__all__ but in the plugins directory.""" # discover relative "path" to the plugins module from the main app, and import it. # TODO: # gross, this probably isn't necessary anymore, especially for an installed module - __import__("%s.%s" % (globals()['plugins'].__name__, plugin)) - # init the plugin class - self.plugin[plugin] = getattr(getattr(plugins, plugin), plugin)(self, pconfig) # eek - # all of this for a nice debug? sure. - xep = '' - if hasattr(self.plugin[plugin], 'xep'): - xep = "(XEP-%s) " % self.plugin[plugin].xep - logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description)) - + try: + if pluginModule: + module = __import__(pluginModule, globals(), locals(), [plugin]) + else: + module = __import__("%s.%s" % (globals()['plugins'].__name__, plugin), globals(), locals(), [plugin]) + # init the plugin class + self.plugin[plugin] = getattr(module, plugin)(self, pconfig) # eek + # all of this for a nice debug? sure. + xep = '' + if hasattr(self.plugin[plugin], 'xep'): + xep = "(XEP-%s) " % self.plugin[plugin].xep + logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description)) + except Exception, e: + logging.error("Unable to load plugin: %s" %(plugin) ) + logging.exception(e) + def register_plugins(self): """Initiates all plugins in the plugins/__init__.__all__""" if self.plugin_whitelist: diff --git a/sleekxmpp/plugins/jobs.py b/sleekxmpp/plugins/jobs.py new file mode 100644 index 00000000..bb2e2554 --- /dev/null +++ b/sleekxmpp/plugins/jobs.py @@ -0,0 +1,44 @@ +from . import base +import logging +from xml.etree import cElementTree as ET + +class jobs(base.base_plugin): + def plugin_init(self): + self.xep = 'pubsubjob' + self.description = "Job distribution over Pubsub" + + def post_init(self): + pass + #TODO add event + + def createJobNode(self, host, jid, node, config=None): + pass + + def createJob(self, host, node, jobid=None, payload=None): + return self.xmpp.plugin['xep_0060'].setItem(host, node, ((jobid, payload),)) + + def claimJob(self, host, node, jobid, ifrom=None): + return self._setState(host, node, jobid, ET.Element('{http://andyet.net/protocol/pubsubjob}claimed')) + + def unclaimJob(self, jobid): + return self._setState(host, node, jobid, ET.Element('{http://andyet.net/protocol/pubsubjob}unclaimed')) + + def finishJob(self, host, node, jobid, payload=None): + finished = ET.Element('{http://andyet.net/protocol/pubsubjob}finished') + if payload is not None: + finished.append(payload) + return self._setState(host, node, jobid, finished) + + def _setState(self, host, node, jobid, state, ifrom=None): + iq = self.xmpp.Iq() + iq['to'] = host + if ifrom: iq['from'] = ifrom + iq['type'] = 'set' + iq['psstate']['node'] = node + iq['psstate']['item'] = jobid + iq['psstate']['payload'] = state + result = iq.send() + if result is None or result['type'] != 'result': + return False + return True + diff --git a/sleekxmpp/plugins/stanza_pubsub.py b/sleekxmpp/plugins/stanza_pubsub.py index 1dd73d99..1a1526f0 100644 --- a/sleekxmpp/plugins/stanza_pubsub.py +++ b/sleekxmpp/plugins/stanza_pubsub.py @@ -10,6 +10,39 @@ def stanzaPlugin(stanza, plugin): stanza.plugin_attrib_map[plugin.plugin_attrib] = plugin stanza.plugin_tag_map["{%s}%s" % (plugin.namespace, plugin.name)] = plugin +class PubsubState(ElementBase): + namespace = 'http://jabber.org/protocol/psstate' + name = 'state' + plugin_attrib = 'psstate' + interfaces = set(('node', 'item', 'payload')) + plugin_attrib_map = {} + plugin_tag_map = {} + + def setPayload(self, value): + self.xml.append(value) + + def getPayload(self): + childs = self.xml.getchildren() + if len(childs) > 0: + return childs[0] + + def delPayload(self): + for child in self.xml.getchildren(): + self.xml.remove(child) + +stanzaPlugin(Iq, PubsubState) + +class PubsubStateEvent(ElementBase): + namespace = 'http://jabber.org/protocol/psstate#event' + name = 'event' + plugin_attrib = 'psstate_event' + intefaces = set(tuple()) + plugin_attrib_map = {} + plugin_tag_map = {} + +stanzaPlugin(Message, PubsubStateEvent) +stanzaPlugin(PubsubStateEvent, PubsubState) + class Pubsub(ElementBase): namespace = 'http://jabber.org/protocol/pubsub' name = 'pubsub' @@ -321,18 +354,6 @@ class Options(ElementBase): stanzaPlugin(Pubsub, Options) stanzaPlugin(Subscribe, Options) -#iq = Iq() -#iq['pubsub']['defaultconfig'] -#print(iq) - -#from xml.etree import cElementTree as ET -#iq = Iq() -#item = Item() -#item['payload'] = ET.Element("{http://netflint.net/p/crap}stupidshit") -#item['id'] = 'aa11bbcc' -#iq['pubsub']['items'].append(item) -#print(iq) - class OwnerAffiliations(Affiliations): namespace = 'http://jabber.org/protocol/pubsub#owner' interfaces = set(('node')) diff --git a/sleekxmpp/plugins/xep_0004.py b/sleekxmpp/plugins/xep_0004.py index 56d18929..015bd8bc 100644 --- a/sleekxmpp/plugins/xep_0004.py +++ b/sleekxmpp/plugins/xep_0004.py @@ -188,7 +188,6 @@ class Form(FieldContainer): #def getXML(self, tostring = False): def getXML(self, ftype=None): - logging.debug("creating form as %s" % ftype) if ftype: self.type = ftype form = ET.Element('{jabber:x:data}x') diff --git a/sleekxmpp/plugins/xep_0030.py b/sleekxmpp/plugins/xep_0030.py index 5432dd56..6a31d243 100644 --- a/sleekxmpp/plugins/xep_0030.py +++ b/sleekxmpp/plugins/xep_0030.py @@ -1,25 +1,184 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2007 Nathanael C. Fritz - This file is part of SleekXMPP. - - SleekXMPP is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - SleekXMPP is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with SleekXMPP; if not, write to the Free Software - Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz, Lance J.T. Stout + This file is part of SleekXMPP. + + See the file license.txt for copying permissio """ -from . import base + import logging -from xml.etree import cElementTree as ET +from . import base +from .. xmlstream.handler.callback import Callback +from .. xmlstream.matcher.xpath import MatchXPath +from .. xmlstream.stanzabase import ElementBase, ET, JID +from .. stanza.iq import Iq + +class DiscoInfo(ElementBase): + namespace = 'http://jabber.org/protocol/disco#info' + name = 'query' + plugin_attrib = 'disco_info' + interfaces = set(('node', 'features', 'identities')) + + def getFeatures(self): + features = [] + featuresXML = self.xml.findall('{%s}feature' % self.namespace) + for feature in featuresXML: + features.append(feature.attrib['var']) + return features + + def setFeatures(self, features): + self.delFeatures() + for name in features: + self.addFeature(name) + + def delFeatures(self): + featuresXML = self.xml.findall('{%s}feature' % self.namespace) + for feature in featuresXML: + self.xml.remove(feature) + + def addFeature(self, feature): + featureXML = ET.Element('{%s}feature' % self.namespace, + {'var': feature}) + self.xml.append(featureXML) + + def delFeature(self, feature): + featuresXML = self.xml.findall('{%s}feature' % self.namespace) + for featureXML in featuresXML: + if featureXML.attrib['var'] == feature: + self.xml.remove(featureXML) + + def getIdentities(self): + ids = [] + idsXML = self.xml.findall('{%s}identity' % self.namespace) + for idXML in idsXML: + idData = (idXML.attrib['category'], + idXML.attrib['type'], + idXML.attrib.get('name', '')) + ids.append(idData) + return ids + + def setIdentities(self, ids): + self.delIdentities() + for idData in ids: + self.addIdentity(*idData) + + def delIdentities(self): + idsXML = self.xml.findall('{%s}identity' % self.namespace) + for idXML in idsXML: + self.xml.remove(idXML) + + def addIdentity(self, category, id_type, name=''): + idXML = ET.Element('{%s}identity' % self.namespace, + {'category': category, + 'type': id_type, + 'name': name}) + self.xml.append(idXML) + + def delIdentity(self, category, id_type, name=''): + idsXML = self.xml.findall('{%s}identity' % self.namespace) + for idXML in idsXML: + idData = (idXML.attrib['category'], + idXML.attrib['type']) + delId = (category, id_type) + if idData == delId: + self.xml.remove(idXML) + + +class DiscoItems(ElementBase): + namespace = 'http://jabber.org/protocol/disco#items' + name = 'query' + plugin_attrib = 'disco_items' + interfaces = set(('node', 'items')) + + def getItems(self): + items = [] + itemsXML = self.xml.findall('{%s}item' % self.namespace) + for item in itemsXML: + itemData = (item.attrib['jid'], + item.attrib.get('node'), + item.attrib.get('name')) + items.append(itemData) + return items + + def setItems(self, items): + self.delItems() + for item in items: + self.addItem(*item) + + def delItems(self): + itemsXML = self.xml.findall('{%s}item' % self.namespace) + for item in itemsXML: + self.xml.remove(item) + + def addItem(self, jid, node='', name=''): + itemXML = ET.Element('{%s}item' % self.namespace, {'jid': jid}) + if name: + itemXML.attrib['name'] = name + if node: + itemXML.attrib['node'] = node + self.xml.append(itemXML) + + def delItem(self, jid, node=''): + itemsXML = self.xml.findall('{%s}item' % self.namespace) + for itemXML in itemsXML: + itemData = (itemXML.attrib['jid'], + itemXML.attrib.get('node', '')) + itemDel = (jid, node) + if itemData == itemDel: + self.xml.remove(itemXML) + + +class DiscoNode(object): + """ + Collection object for grouping info and item information + into nodes. + """ + def __init__(self, name): + self.name = name + self.info = DiscoInfo() + self.items = DiscoItems() + + # This is a bit like poor man's inheritance, but + # to simplify adding information to the node we + # map node functions to either the info or items + # stanza objects. + # + # We don't want to make DiscoNode inherit from + # DiscoInfo and DiscoItems because DiscoNode is + # not an actual stanza, and doing so would create + # confusion and potential bugs. + + self._map(self.items, 'items', ['get', 'set', 'del']) + self._map(self.items, 'item', ['add', 'del']) + self._map(self.info, 'identities', ['get', 'set', 'del']) + self._map(self.info, 'identity', ['add', 'del']) + self._map(self.info, 'features', ['get', 'set', 'del']) + self._map(self.info, 'feature', ['add', 'del']) + + def isEmpty(self): + """ + Test if the node contains any information. Useful for + determining if a node can be deleted. + """ + ids = self.getIdentities() + features = self.getFeatures() + items = self.getItems() + + if not ids and not features and not items: + return True + return False + + def _map(self, obj, interface, access): + """ + Map functions of the form obj.accessInterface + to self.accessInterface for each given access type. + """ + interface = interface.title() + for access_type in access: + method = access_type + interface + if hasattr(obj, method): + setattr(self, method, getattr(obj, method)) + class xep_0030(base.base_plugin): """ @@ -29,85 +188,137 @@ class xep_0030(base.base_plugin): def plugin_init(self): self.xep = '0030' self.description = 'Service Discovery' - self.features = {'main': ['http://jabber.org/protocol/disco#info', 'http://jabber.org/protocol/disco#items']} - self.identities = {'main': [{'category': 'client', 'type': 'pc', 'name': 'SleekXMPP'}]} - self.items = {'main': []} - self.xmpp.add_handler("<iq type='get' xmlns='%s'><query xmlns='http://jabber.org/protocol/disco#info' /></iq>" % self.xmpp.default_ns, self.info_handler) - self.xmpp.add_handler("<iq type='get' xmlns='%s'><query xmlns='http://jabber.org/protocol/disco#items' /></iq>" % self.xmpp.default_ns, self.item_handler) + + self.xmpp.registerHandler( + Callback('Disco Items', + MatchXPath('{%s}iq/{%s}query' % (self.xmpp.default_ns, + DiscoItems.namespace)), + self.handle_item_query)) + + self.xmpp.registerHandler( + Callback('Disco Info', + MatchXPath('{%s}iq/{%s}query' % (self.xmpp.default_ns, + DiscoInfo.namespace)), + self.handle_info_query)) + + self.xmpp.stanzaPlugin(Iq, DiscoInfo) + self.xmpp.stanzaPlugin(Iq, DiscoItems) + + self.xmpp.add_event_handler('disco_items_request', self.handle_disco_items) + self.xmpp.add_event_handler('disco_info_request', self.handle_disco_info) + + self.nodes = {'main': DiscoNode('main')} + + def add_node(self, node): + if node not in self.nodes: + self.nodes[node] = DiscoNode(node) + + def del_node(self, node): + if node in self.nodes: + del self.nodes[node] + + def handle_item_query(self, iq): + if iq['type'] == 'get': + logging.debug("Items requested by %s" % iq['from']) + self.xmpp.event('disco_items_request', iq) + elif iq['type'] == 'result': + logging.debug("Items result from %s" % iq['from']) + self.xmpp.event('disco_items', iq) + + def handle_info_query(self, iq): + if iq['type'] == 'get': + logging.debug("Info requested by %s" % iq['from']) + self.xmpp.event('disco_info_request', iq) + elif iq['type'] == 'result': + logging.debug("Info result from %s" % iq['from']) + self.xmpp.event('disco_info', iq) + + def handle_disco_info(self, iq, forwarded=False): + """ + A default handler for disco#info requests. If another + handler is registered, this one will defer and not run. + """ + handlers = self.xmpp.event_handlers['disco_info_request'] + if not forwarded and len(handlers) > 1: + return + + node_name = iq['disco_info']['node'] + if not node_name: + node_name = 'main' + + logging.debug("Using default handler for disco#info on node '%s'." % node_name) + + if node_name in self.nodes: + node = self.nodes[node_name] + iq.reply().setPayload(node.info.xml).send() + else: + logging.debug("Node %s requested, but does not exist." % node_name) + iq.reply().error().setPayload(iq['disco_info'].xml) + iq['error']['code'] = '404' + iq['error']['type'] = 'cancel' + iq['error']['condition'] = 'item-not-found' + iq.send() + + def handle_disco_items(self, iq, forwarded=False): + """ + A default handler for disco#items requests. If another + handler is registered, this one will defer and not run. + + If this handler is called by your own custom handler with + forwarded set to True, then it will run as normal. + """ + handlers = self.xmpp.event_handlers['disco_items_request'] + if not forwarded and len(handlers) > 1: + return + + node_name = iq['disco_items']['node'] + if not node_name: + node_name = 'main' + + logging.debug("Using default handler for disco#items on node '%s'." % node_name) + + if node_name in self.nodes: + node = self.nodes[node_name] + iq.reply().setPayload(node.items.xml).send() + else: + logging.debug("Node %s requested, but does not exist." % node_name) + iq.reply().error().setPayload(iq['disco_items'].xml) + iq['error']['code'] = '404' + iq['error']['type'] = 'cancel' + iq['error']['condition'] = 'item-not-found' + iq.send() + + # Older interface methods for backwards compatibility + + def getInfo(self, jid, node=''): + iq = self.xmpp.Iq() + iq['type'] = 'get' + iq['to'] = jid + iq['from'] = self.xmpp.fulljid + iq['disco_info']['node'] = node + iq.send() + + def getItems(self, jid, node=''): + iq = self.xmpp.Iq() + iq['type'] = 'get' + iq['to'] = jid + iq['from'] = self.xmpp.fulljid + iq['disco_items']['node'] = node + iq.send() def add_feature(self, feature, node='main'): - if not node in self.features: - self.features[node] = [] - self.features[node].append(feature) - - def add_identity(self, category=None, itype=None, name=None, node='main'): - if not node in self.identities: - self.identities[node] = [] - self.identities[node].append({'category': category, 'type': itype, 'name': name}) + self.add_node(node) + self.nodes[node].addFeature(feature) - def add_item(self, jid=None, name=None, node='main', subnode=''): - if not node in self.items: - self.items[node] = [] - self.items[node].append({'jid': jid, 'name': name, 'node': subnode}) - - def info_handler(self, xml): - logging.debug("Info request from %s" % xml.get('from', '')) - iq = self.xmpp.makeIqResult(xml.get('id', self.xmpp.getNewId())) - iq.attrib['from'] = xml.get('to') - iq.attrib['to'] = xml.get('from', self.xmpp.server) - query = xml.find('{http://jabber.org/protocol/disco#info}query') - node = query.get('node', 'main') - for identity in self.identities.get(node, []): - idxml = ET.Element('identity') - for attrib in identity: - if identity[attrib]: - idxml.attrib[attrib] = identity[attrib] - query.append(idxml) - for feature in self.features.get(node, []): - featxml = ET.Element('feature') - featxml.attrib['var'] = feature - query.append(featxml) - iq.append(query) - #print ET.tostring(iq) - self.xmpp.send(iq) - - def item_handler(self, xml): - logging.debug("Item request from %s" % xml.get('from', '')) - iq = self.xmpp.makeIqResult(xml.get('id', self.xmpp.getNewId())) - iq.attrib['from'] = xml.get('to') - iq.attrib['to'] = xml.get('from', self.xmpp.server) - query = self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#items').find('{http://jabber.org/protocol/disco#items}query') - node = xml.find('{http://jabber.org/protocol/disco#items}query').get('node', 'main') - for item in self.items.get(node, []): - itemxml = ET.Element('item') - itemxml.attrib = item - if itemxml.attrib['jid'] is None: - itemxml.attrib['jid'] = xml.get('to') - query.append(itemxml) - self.xmpp.send(iq) + def add_identity(self, category='', itype='', name='', node='main'): + self.add_node(node) + self.nodes[node].addIdentity(category=category, + id_type=itype, + name=name) - def getItems(self, jid, node=None): - iq = self.xmpp.makeIqGet() - iq.attrib['from'] = self.xmpp.fulljid - iq.attrib['to'] = jid - self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#items') - if node: - iq.find('{http://jabber.org/protocol/disco#items}query').attrib['node'] = node - return iq.send() - - def getInfo(self, jid, node=None): - iq = self.xmpp.makeIqGet() - iq.attrib['from'] = self.xmpp.fulljid - iq.attrib['to'] = jid - self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#info') - if node: - iq.find('{http://jabber.org/protocol/disco#info}query').attrib['node'] = node - return iq.send() - - def parseInfo(self, xml): - result = {'identity': {}, 'feature': []} - for identity in xml.findall('{http://jabber.org/protocol/disco#info}query/{{http://jabber.org/protocol/disco#info}identity'): - result['identity'][identity['name']] = identity.attrib - for feature in xml.findall('{http://jabber.org/protocol/disco#info}query/{{http://jabber.org/protocol/disco#info}feature'): - result['feature'].append(feature.get('var', '__unknown__')) - return result + def add_item(self, jid=None, name='', node='main', subnode=''): + self.add_node(node) + self.add_node(subnode) + if jid is None: + jid = self.xmpp.fulljid + self.nodes[node].addItem(jid=jid, name=name, node=subnode) diff --git a/sleekxmpp/plugins/xep_0060.py b/sleekxmpp/plugins/xep_0060.py index 44a70e9a..bff158a0 100644 --- a/sleekxmpp/plugins/xep_0060.py +++ b/sleekxmpp/plugins/xep_0060.py @@ -14,12 +14,14 @@ class xep_0060(base.base_plugin): self.xep = '0060' self.description = 'Publish-Subscribe' - def create_node(self, jid, node, config=None, collection=False): + def create_node(self, jid, node, config=None, collection=False, ntype=None): pubsub = ET.Element('{http://jabber.org/protocol/pubsub}pubsub') create = ET.Element('create') create.set('node', node) pubsub.append(create) configure = ET.Element('configure') + if collection: + ntype = 'collection' #if config is None: # submitform = self.xmpp.plugin['xep_0004'].makeForm('submit') #else: @@ -29,11 +31,11 @@ class xep_0060(base.base_plugin): submitform.field['FORM_TYPE'].setValue('http://jabber.org/protocol/pubsub#node_config') else: submitform.addField('FORM_TYPE', 'hidden', value='http://jabber.org/protocol/pubsub#node_config') - if collection: + if ntype: if 'pubsub#node_type' in submitform.field: - submitform.field['pubsub#node_type'].setValue('collection') + submitform.field['pubsub#node_type'].setValue(ntype) else: - submitform.addField('pubsub#node_type', value='collection') + submitform.addField('pubsub#node_type', value=ntype) else: if 'pubsub#node_type' in submitform.field: submitform.field['pubsub#node_type'].setValue('leaf') diff --git a/sleekxmpp/stanza/error.py b/sleekxmpp/stanza/error.py index f87b6490..3346ceb2 100644 --- a/sleekxmpp/stanza/error.py +++ b/sleekxmpp/stanza/error.py @@ -1,9 +1,9 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ from .. xmlstream.stanzabase import ElementBase, ET @@ -11,7 +11,7 @@ class Error(ElementBase): namespace = 'jabber:client' name = 'error' plugin_attrib = 'error' - conditions = set(('bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'item-not-found', 'jid-malformed', 'not-acceptable', 'not-allowed', 'not-authorized', 'payment-required', 'recipient-unavailable', 'redirect', 'registration-required', 'remote-server-not-found', 'remote-server-timeout', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request')) + conditions = set(('bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'internal-server-error', 'item-not-found', 'jid-malformed', 'not-acceptable', 'not-allowed', 'not-authorized', 'payment-required', 'recipient-unavailable', 'redirect', 'registration-required', 'remote-server-not-found', 'remote-server-timeout', 'resource-constraint', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request')) interfaces = set(('code', 'condition', 'text', 'type')) types = set(('cancel', 'continue', 'modify', 'auth', 'wait')) sub_interfaces = set(('text',)) diff --git a/sleekxmpp/stanza/iq.py b/sleekxmpp/stanza/iq.py index ded7515f..26f09268 100644 --- a/sleekxmpp/stanza/iq.py +++ b/sleekxmpp/stanza/iq.py @@ -1,9 +1,9 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ from .. xmlstream.stanzabase import StanzaBase from xml.etree import cElementTree as ET @@ -67,11 +67,11 @@ class Iq(RootStanza): self.xml.remove(child) return self - def send(self, block=True, timeout=10): + def send(self, block=True, timeout=10, priority=False): if block and self['type'] in ('get', 'set'): waitfor = Waiter('IqWait_%s' % self['id'], MatcherId(self['id'])) self.stream.registerHandler(waitfor) - StanzaBase.send(self) + StanzaBase.send(self, priority) return waitfor.wait(timeout) else: - return StanzaBase.send(self) + return StanzaBase.send(self, priority) diff --git a/sleekxmpp/xmlstream/handler/base.py b/sleekxmpp/xmlstream/handler/base.py index 5d55f4ee..a44edf0e 100644 --- a/sleekxmpp/xmlstream/handler/base.py +++ b/sleekxmpp/xmlstream/handler/base.py @@ -18,7 +18,7 @@ class BaseHandler(object): def match(self, xml): return self._matcher.match(xml) - def prerun(self, payload): + def prerun(self, payload): # what's the point of this if the payload is called again in run?? self._payload = payload def run(self, payload): diff --git a/sleekxmpp/xmlstream/handler/callback.py b/sleekxmpp/xmlstream/handler/callback.py index 49cfa14d..ea5acb5b 100644 --- a/sleekxmpp/xmlstream/handler/callback.py +++ b/sleekxmpp/xmlstream/handler/callback.py @@ -17,13 +17,15 @@ class Callback(base.BaseHandler): self._once = once self._instream = instream - def prerun(self, payload): + def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN! base.BaseHandler.prerun(self, payload) if self._instream: + logging.debug('callback "%s" prerun', self.name) self.run(payload, True) def run(self, payload, instream=False): if not self._instream or instream: + logging.debug('callback "%s" run', self.name) base.BaseHandler.run(self, payload) #if self._thread: # x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,)) diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py new file mode 100644 index 00000000..40aaf695 --- /dev/null +++ b/sleekxmpp/xmlstream/scheduler.py @@ -0,0 +1,87 @@ +try: + import queue +except ImportError: + import Queue as queue +import time +import threading +import logging + +class Task(object): + """Task object for the Scheduler class""" + def __init__(self, name, seconds, callback, args=None, kwargs=None, repeat=False, qpointer=None): + self.name = name + self.seconds = seconds + self.callback = callback + self.args = args or tuple() + self.kwargs = kwargs or {} + self.repeat = repeat + self.next = time.time() + self.seconds + self.qpointer = qpointer + + def run(self): + if self.qpointer is not None: + self.qpointer.put(('schedule', self.callback, self.args)) + else: + self.callback(*self.args, **self.kwargs) + self.reset() + return self.repeat + + def reset(self): + self.next = time.time() + self.seconds + +class Scheduler(object): + """Threaded scheduler that allows for updates mid-execution unlike http://docs.python.org/library/sched.html#module-sched""" + def __init__(self, parentqueue=None): + self.addq = queue.Queue() + self.schedule = [] + self.thread = None + self.run = False + self.parentqueue = parentqueue + + def process(self, threaded=True): + if threaded: + self.thread = threading.Thread(name='shedulerprocess', target=self._process) + self.thread.start() + else: + self._process() + + def _process(self): + self.run = True + while self.run: + try: + wait = 1 + updated = False + if self.schedule: + wait = self.schedule[0].next - time.time() + try: + if wait <= 0.0: + newtask = self.addq.get(False) + else: + newtask = self.addq.get(True, wait) + except queue.Empty: + cleanup = [] + for task in self.schedule: + if time.time() >= task.next: + updated = True + if not task.run(): + cleanup.append(task) + else: + break + for task in cleanup: + x = self.schedule.pop(self.schedule.index(task)) + else: + updated = True + self.schedule.append(newtask) + finally: + if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next) + except KeyboardInterrupt: + self.run = False + logging.debug("Quitting Scheduler thread") + if self.parentqueue is not None: + self.parentqueue.put(('quit', None, None)) + + def add(self, name, seconds, callback, args=None, kwargs=None, repeat=False, qpointer=None): + self.addq.put(Task(name, seconds, callback, args, kwargs, repeat, qpointer)) + + def quit(self): + self.run = False diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py index 3f3f5e08..34513807 100644 --- a/sleekxmpp/xmlstream/stanzabase.py +++ b/sleekxmpp/xmlstream/stanzabase.py @@ -1,9 +1,9 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ from xml.etree import cElementTree as ET import logging @@ -78,6 +78,9 @@ class ElementBase(tostring.ToString): def __iter__(self): self.idx = 0 return self + + def __bool__(self): + return True def __next__(self): self.idx += 1 @@ -319,6 +322,8 @@ class StanzaBase(ElementBase): def __init__(self, stream=None, xml=None, stype=None, sto=None, sfrom=None, sid=None): self.stream = stream + if stream is not None: + self.namespace = stream.default_ns ElementBase.__init__(self, xml) if stype is not None: self['type'] = stype @@ -326,8 +331,6 @@ class StanzaBase(ElementBase): self['to'] = sto if sfrom is not None: self['from'] = sfrom - if stream is not None: - self.namespace = stream.default_ns self.tag = "{%s}%s" % (self.namespace, self.name) def setType(self, value): @@ -380,6 +383,7 @@ class StanzaBase(ElementBase): def exception(self, e): logging.error(traceback.format_tb(e)) - def send(self): - self.stream.sendRaw(self.__str__()) - + def send(self, priority=False): + if priority: self.stream.sendPriorityRaw(self.__str__()) + else: self.stream.sendRaw(self.__str__()) + diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py index fb7d1508..67b514a2 100644 --- a/sleekxmpp/xmlstream/statemachine.py +++ b/sleekxmpp/xmlstream/statemachine.py @@ -7,53 +7,228 @@ """ from __future__ import with_statement import threading +import time +import logging + class StateMachine(object): - def __init__(self, states=[], groups=[]): - self.lock = threading.Lock() - self.__state = {} - self.__default_state = {} - self.__group = {} + def __init__(self, states=[]): + self.lock = threading.Condition(threading.RLock()) + self.__states= [] self.addStates(states) - self.addGroups(groups) + self.__default_state = self.__states[0] + self.__current_state = self.__default_state def addStates(self, states): with self.lock: for state in states: - if state in self.__state or state in self.__group: - raise IndexError("The state or group '%s' is already in the StateMachine." % state) - self.__state[state] = states[state] - self.__default_state[state] = states[state] + if state in self.__states: + raise IndexError("The state '%s' is already in the StateMachine." % state) + self.__states.append( state ) - def addGroups(self, groups): - with self.lock: - for gstate in groups: - if gstate in self.__state or gstate in self.__group: - raise IndexError("The key or group '%s' is already in the StateMachine." % gstate) - for state in groups[gstate]: - if state in self.__state: - raise IndexError("The group %s contains a key %s which is not set in the StateMachine." % (gstate, state)) - self.__group[gstate] = groups[gstate] - - def set(self, state, status): + + def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ): + ''' + Transition from the given `from_state` to the given `to_state`. + This method will return `True` if the state machine is now in `to_state`. It + will return `False` if a timeout occurred the transition did not occur. + If `wait` is 0 (the default,) this method returns immediately if the state machine + is not in `from_state`. + + If you want the thread to block and transition once the state machine to enters + `from_state`, set `wait` to a non-negative value. Note there is no 'block + indefinitely' flag since this leads to deadlock. If you want to wait indefinitely, + choose a reasonable value for `wait` (e.g. 20 seconds) and do so in a while loop like so: + + :: + + while not thread_should_exit and not state_machine.transition('disconnected', 'connecting', wait=20 ): + pass # timeout will occur every 20s unless transition occurs + if thread_should_exit: return + # perform actions here after successful transition + + This allows the thread to be responsive by setting `thread_should_exit=True`. + + The optional `func` argument allows the user to pass a callable operation which occurs + within the context of the state transition (e.g. while the state machine is locked.) + If `func` returns a True value, the transition will occur. If `func` returns a non- + True value or if an exception is thrown, the transition will not occur. Any thrown + exception is not caught by the state machine and is the caller's responsibility to handle. + If `func` completes normally, this method will return the value returned by `func.` If + values for `args` and `kwargs` are provided, they are expanded and passed like so: + `func( *args, **kwargs )`. + ''' + + return self.transition_any( (from_state,), to_state, wait=wait, + func=func, args=args, kwargs=kwargs ) + + + def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ): + ''' + Transition from any of the given `from_states` to the given `to_state`. + ''' + + if not (isinstance(from_states,tuple) or isinstance(from_states,list)): + raise ValueError( "from_states should be a list or tuple" ) + + for state in from_states: + if not state in self.__states: + raise ValueError( "StateMachine does not contain from_state %s." % state ) + if not to_state in self.__states: + raise ValueError( "StateMachine does not contain to_state %s." % to_state ) + with self.lock: - if state in self.__state: - self.__state[state] = bool(status) + start = time.time() + while not self.__current_state in from_states: + # detect timeout: + if time.time() >= start + wait: return False + self.lock.wait(wait) + + if self.__current_state in from_states: # should always be True due to lock + + return_val = True + # Note that func might throw an exception, but that's OK, it aborts the transition + if func is not None: return_val = func(*args,**kwargs) + + # some 'false' value returned from func, + # indicating that transition should not occur: + if not return_val: return return_val + + logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) + self.__current_state = to_state + self.lock.notify_all() + return return_val # some 'true' value returned by func or True if func was None else: - raise KeyError("StateMachine does not contain state %s." % state) - - def __getitem__(self, key): - if key in self.__group: - for state in self.__group[key]: - if not self.__state[state]: - return False - return True - return self.__state[key] + logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) + return False + + + def transition_ctx(self, from_state, to_state, wait=0.0): + ''' + Use the state machine as a context manager. The transition occurs on /exit/ from + the `with` context, so long as no exception is thrown. For example: + + :: + + with state_machine.transition_ctx('one','two', wait=5) as locked: + if locked: + # the state machine is currently locked in state 'one', and will + # transition to 'two' when the 'with' statement ends, so long as + # no exception is thrown. + print 'Currently locked in state one: %s' % state_machine['one'] + + else: + # The 'wait' timed out, and no lock has been acquired + print 'Timed out before entering state "one"' + + print 'Since no exception was thrown, we are now in state "two": %s' % state_machine['two'] + + + The other main difference between this method and `transition()` is that the + state machine is locked for the duration of the `with` statement. Normally, + after a `transition()` occurs, the state machine is immediately unlocked and + available to another thread to call `transition()` again. + ''' + + if not from_state in self.__states: + raise ValueError( "StateMachine does not contain from_state %s." % from_state ) + if not to_state in self.__states: + raise ValueError( "StateMachine does not contain to_state %s." % to_state ) + + return _StateCtx(self, from_state, to_state, wait) + - def __getattr__(self, attr): - return self.__getitem__(attr) + def ensure(self, state, wait=0.0): + ''' + Ensure the state machine is currently in `state`, or wait until it enters `state`. + ''' + return self.ensure_any( (state,), wait=wait ) + + + def ensure_any(self, states, wait=0.0): + ''' + Ensure we are currently in one of the given `states` + ''' + if not (isinstance(states,tuple) or isinstance(states,list)): + raise ValueError('states arg should be a tuple or list') + + for state in states: + if not state in self.__states: + raise ValueError( "StateMachine does not contain state '%s'" % state ) + + with self.lock: + start = time.time() + while not self.__current_state in states: + # detect timeout: + if time.time() >= start + wait: return False + self.lock.wait(wait) + return self.__current_state in states # should always be True due to lock + def reset(self): - self.__state = self.__default_state + # TODO need to lock before calling this? + self.transition(self.__current_state, self._default_state) + + + def _set_state(self, state): #unsynchronized, only call internally after lock is acquired + self.__current_state = state + return state + + + def current_state(self): + ''' + Return the current state name. + ''' + return self.__current_state + + + def __getitem__(self, state): + ''' + Non-blocking, non-synchronized test to determine if we are in the given state. + Use `StateMachine.ensure(state)` to wait until the machine enters a certain state. + ''' + return self.__current_state == state + + def __str__(self): + return "".join(( "StateMachine(", ','.join(self.__states), "): ", self.__current_state )) + + + +class _StateCtx: + + def __init__( self, state_machine, from_state, to_state, wait ): + self.state_machine = state_machine + self.from_state = from_state + self.to_state = to_state + self.wait = wait + self._timeout = False + + def __enter__(self): + self.state_machine.lock.acquire() + start = time.time() + while not self.state_machine[ self.from_state ]: + # detect timeout: + if time.time() >= start + self.wait: + logging.debug('StateMachine timeout while waiting for state: %s', self.from_state ) + self._timeout = True # to indicate we should not transition + return False + self.state_machine.lock.wait(self.wait) + + logging.debug('StateMachine entered context in state: %s', + self.state_machine.current_state() ) + return True + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_val is not None: + logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s", + self.state_machine.current_state(), exc_type.__name__, exc_val ) + elif not self._timeout: + logging.debug(' ==== TRANSITION %s -> %s', + self.state_machine.current_state(), self.to_state) + self.state_machine._set_state( self.to_state ) + + self.state_machine.lock.notify_all() + self.state_machine.lock.release() + return False # re-raise any exception diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 025884b7..a8bcac00 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -1,9 +1,9 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ from __future__ import with_statement, unicode_literals @@ -16,12 +16,14 @@ from . stanzabase import StanzaBase from xml.etree import cElementTree from xml.parsers import expat import logging +import random import socket import threading import time import traceback import types import xml.sax.saxutils +from . import scheduler HANDLER_THREADS = 1 @@ -45,6 +47,10 @@ class CloseStream(Exception): stanza_extensions = {} +RECONNECT_MAX_DELAY = 3600 +RECONNECT_QUIESCE_FACTOR = 1.6180339887498948 # Phi +RECONNECT_QUIESCE_JITTER = 0.11962656472 # molar Planck constant times c, joule meter/mole + class XMLStream(object): "A connection manager with XML events." @@ -52,8 +58,9 @@ class XMLStream(object): global ssl_support self.ssl_support = ssl_support self.escape_quotes = escape_quotes - self.state = statemachine.StateMachine() - self.state.addStates({'connected':False, 'is client':False, 'ssl':False, 'tls':False, 'reconnect':True, 'processing':False, 'disconnecting':False}) #set initial states + self.state = statemachine.StateMachine(('disconnected','connecting', + 'connected')) + self.should_reconnect = True self.setSocket(socket) self.address = (host, int(port)) @@ -69,12 +76,14 @@ class XMLStream(object): self.filesocket = None self.use_ssl = False self.use_tls = False + self.ca_certs=None self.stream_header = "<stream>" self.stream_footer = "</stream>" self.eventqueue = queue.Queue() - self.sendqueue = queue.Queue() + self.sendqueue = queue.PriorityQueue() + self.scheduler = scheduler.Scheduler(self.eventqueue) self.namespace_map = {} @@ -83,45 +92,77 @@ class XMLStream(object): def setSocket(self, socket): "Set the socket" self.socket = socket - if socket is not None: + if socket is not None and self.state.transition('disconnected','connecting'): self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary - self.state.set('connected', True) - + self.state.transition('connecting','connected') def setFileSocket(self, filesocket): self.filesocket = filesocket - def connect(self, host='', port=0, use_ssl=False, use_tls=True): - "Link to connectTCP" - return self.connectTCP(host, port, use_ssl, use_tls) + def connect(self, host='', port=0, use_ssl=None, use_tls=None): + "Establish a socket connection to the given XMPP server." + + if not self.state.transition('disconnected','connected', + func=self.connectTCP, args=[host, port, use_ssl, use_tls] ): + + if self.state['connected']: logging.debug('Already connected') + else: logging.warning("Connection failed" ) + return False + + logging.debug('Connection complete.') + return True + + # TODO currently a caller can't distinguish between "connection failed" and + # "we're already trying to connect from another thread" def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True): "Connect and create socket" - while reattempt and not self.state['connected']: - if host and port: - self.address = (host, int(port)) - if use_ssl is not None: - self.use_ssl = use_ssl - if use_tls is not None: - self.use_tls = use_tls - self.state.set('is client', True) - if sys.version_info < (3, 0): - self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM) - else: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.settimeout(None) - if self.use_ssl and self.ssl_support: - logging.debug("Socket Wrapped for SSL") - self.socket = ssl.wrap_socket(self.socket) + + # Note that this is thread-safe by merit of being called solely from connect() which + # holds the state lock. + + delay = 1.0 # reconnection delay + while self.run: + logging.debug('connecting....') try: + if host and port: + self.address = (host, int(port)) + if use_ssl is not None: + self.use_ssl = use_ssl + if use_tls is not None: + # TODO this variable doesn't seem to be used for anything! + self.use_tls = use_tls + if sys.version_info < (3, 0): + self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM) + else: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.settimeout(None) #10) + + if self.use_ssl and self.ssl_support: + logging.debug("Socket Wrapped for SSL") + self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs) + self.socket.connect(self.address) - #self.filesocket = self.socket.makefile('rb', 0) self.filesocket = self.socket.makefile('rb', 0) - self.state.set('connected', True) + return True + except socket.error as serr: - logging.error("Could not connect. Socket Error #%s: %s" % (serr.errno, serr.strerror)) - time.sleep(1) + logging.exception("Socket Error #%s: %s", serr.errno, serr.strerror) + if not reattempt: return False + except: + logging.exception("Connection error") + if not reattempt: return False + + # quiesce if rconnection fails: + # This algorithm based loosely on Twisted internet.protocol + # http://twistedmatrix.com/trac/browser/trunk/twisted/internet/protocol.py#L310 + delay = min(delay * RECONNECT_QUIESCE_FACTOR, RECONNECT_MAX_DELAY) + delay = random.normalvariate(delay, delay * RECONNECT_QUIESCE_JITTER) + logging.debug('Waiting %fs until next reconnect attempt...', delay) + time.sleep(delay) + + def connectUnix(self, filepath): "Connect to Unix file and create socket" @@ -130,14 +171,19 @@ class XMLStream(object): "Handshakes for TLS" if self.ssl_support: logging.info("Negotiating TLS") - self.realsocket = self.socket - self.socket = ssl.wrap_socket(self.socket, ssl_version=ssl.PROTOCOL_TLSv1, do_handshake_on_connect=False) +# self.realsocket = self.socket # NOT USED + self.socket = ssl.wrap_socket(self.socket, + ssl_version=ssl.PROTOCOL_TLSv1, + do_handshake_on_connect=False, + ca_certs=self.ca_certs) self.socket.do_handshake() if sys.version_info < (3,0): from . filesocket import filesocket self.filesocket = filesocket(self.socket) else: self.filesocket = self.socket.makefile('rb', 0) + + logging.debug("TLS negotitation successful") return True else: logging.warning("Tried to enable TLS, but ssl module not found.") @@ -145,67 +191,56 @@ class XMLStream(object): raise RestartStream() def process(self, threaded=True): + self.scheduler.process(threaded=True) + self.run = True for t in range(0, HANDLER_THREADS): - self.__thread['eventhandle%s' % t] = threading.Thread(name='eventhandle%s' % t, target=self._eventRunner) - self.__thread['eventhandle%s' % t].start() - self.__thread['sendthread'] = threading.Thread(name='sendthread', target=self._sendThread) - self.__thread['sendthread'].start() + th = threading.Thread(name='eventhandle%s' % t, target=self._eventRunner) + th.setDaemon(True) + self.__thread['eventhandle%s' % t] = th + th.start() + th = threading.Thread(name='sendthread', target=self._sendThread) + th.setDaemon(True) + self.__thread['sendthread'] = th + th.start() if threaded: - self.__thread['process'] = threading.Thread(name='process', target=self._process) - self.__thread['process'].start() + th = threading.Thread(name='process', target=self._process) + th.setDaemon(True) + self.__thread['process'] = th + th.start() else: self._process() - def schedule(self, seconds, handler, args=None): - threading.Timer(seconds, handler, args).start() + def schedule(self, name, seconds, callback, args=None, kwargs=None, repeat=False): + self.scheduler.add(name, seconds, callback, args, kwargs, repeat, qpointer=self.eventqueue) def _process(self): "Start processing the socket." - firstrun = True - while self.run and (firstrun or self.state['reconnect']): - self.state.set('processing', True) - firstrun = False + logging.debug('Process thread starting...') + while self.run: + if not self.state.ensure('connected',wait=2): continue try: - if self.state['is client']: - self.sendRaw(self.stream_header) - while self.run and self.__readXML(): - if self.state['is client']: - self.sendRaw(self.stream_header) - except KeyboardInterrupt: - logging.debug("Keyboard Escape Detected") - self.state.set('processing', False) - self.state.set('reconnect', False) - self.disconnect() - self.run = False - self.eventqueue.put(('quit', None, None)) - return + self.sendPriorityRaw(self.stream_header) + while self.run and self.__readXML(): pass + except socket.timeout: + logging.debug('socket rcv timeout') + pass except CloseStream: - return - except SystemExit: + # TODO warn that the listener thread is exiting!!! + pass + except RestartStream: + logging.debug("Restarting stream...") + continue # DON'T re-initialize the stream -- this exception is sent + # specifically when we've initialized TLS and need to re-send the <stream> header. + except (KeyboardInterrupt, SystemExit): + logging.debug("System interrupt detected") + self.shutdown() self.eventqueue.put(('quit', None, None)) - return - except socket.error: - if not self.state.reconnect: - return - else: - self.state.set('processing', False) - traceback.print_exc() - self.disconnect(reconnect=True) except: - if not self.state.reconnect: - return - else: - self.state.set('processing', False) - traceback.print_exc() + logging.exception('Unexpected error in RCV thread') + if self.should_reconnect: self.disconnect(reconnect=True) - if self.state['reconnect']: - self.reconnect() - self.state.set('processing', False) - self.eventqueue.put(('quit', None, None)) - #self.__thread['readXML'] = threading.Thread(name='readXML', target=self.__readXML) - #self.__thread['readXML'].start() - #self.__thread['spawnEvents'] = threading.Thread(name='spawnEvents', target=self.__spawnEvents) - #self.__thread['spawnEvents'].start() + + logging.debug('Quitting Process thread') def __readXML(self): "Parses the incoming stream, adding to xmlin queue as it goes" @@ -218,82 +253,94 @@ class XMLStream(object): if edepth == 0: # and xmlobj.tag.split('}', 1)[-1] == self.basetag: if event == b'start': root = xmlobj + logging.debug('handling start stream') self.start_stream_handler(root) if event == b'end': edepth += -1 if edepth == 0 and event == b'end': - self.disconnect(reconnect=self.state['reconnect']) + # what is this case exactly? Premature EOF? + logging.debug("Ending readXML loop") return False elif edepth == 1: #self.xmlin.put(xmlobj) - try: - self.__spawnEvent(xmlobj) - except RestartStream: - return True - except CloseStream: - return False - if root: - root.clear() + self.__spawnEvent(xmlobj) + if root: root.clear() if event == b'start': edepth += 1 + logging.debug("Exiting readXML loop") + return False def _sendThread(self): + logging.debug('send thread starting...') while self.run: - data = self.sendqueue.get(True) - logging.debug("SEND: %s" % data) + if not self.state.ensure('connected',wait=2): continue + + data = None try: - self.socket.send(data.encode('utf-8')) - #self.socket.send(bytes(data, "utf-8")) - #except socket.error,(errno, strerror): + data = self.sendqueue.get(True,5)[1] + logging.debug("SEND: %s" % data) + self.socket.sendall(data.encode('utf-8')) + except queue.Empty: +# logging.debug('Nothing on send queue') + pass + except socket.timeout: + # this is to prevent a thread blocked indefinitely + logging.debug('timeout sending packet data') except: logging.warning("Failed to send %s" % data) - self.state.set('connected', False) - if self.state.reconnect: - logging.error("Disconnected. Socket Error.") - traceback.print_exc() + logging.exception("Socket error in SEND thread") + # TODO it's somewhat unsafe for the sender thread to assume it can just + # re-intitialize the connection, since the receiver thread could be doing + # the same thing concurrently. Oops! The safer option would be to throw + # some sort of event that could be handled by a common thread or the reader + # thread to perform reconnect and then re-initialize the handler threads as well. + if self.should_reconnect: self.disconnect(reconnect=True) def sendRaw(self, data): - self.sendqueue.put(data) + self.sendqueue.put((1, data)) + return True + + def sendPriorityRaw(self, data): + self.sendqueue.put((0, data)) return True def disconnect(self, reconnect=False): - self.state.set('reconnect', reconnect) - if self.state['disconnecting']: + if not self.state.transition('connected','disconnected'): + logging.warning("Already disconnected.") return - if not self.state['reconnect']: - logging.debug("Disconnecting...") - self.state.set('disconnecting', True) - self.run = False - if self.state['connected']: - self.sendRaw(self.stream_footer) - time.sleep(1) - #send end of stream - #wait for end of stream back + logging.debug("Disconnecting...") + self.sendPriorityRaw(self.stream_footer) + time.sleep(5) + #send end of stream + #wait for end of stream back try: +# self.socket.shutdown(socket.SHUT_RDWR) self.socket.close() + except socket.error as (errno,strerror): + logging.exception("Error while disconnecting. Socket Error #%s: %s" % (errno, strerror)) + try: self.filesocket.close() - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error as serr: - #logging.warning("Error while disconnecting. Socket Error #%s: %s" % (errno, strerror)) - #thread.exit_thread() - pass - if self.state['processing']: - #raise CloseStream - pass - - def reconnect(self): - self.state.set('tls',False) - self.state.set('ssl',False) - time.sleep(1) - self.connect() + except socket.error as (errno,strerror): + logging.exception("Error closing filesocket.") + + if reconnect: self.connect() + def shutdown(self): + ''' + Disconnects and shuts down all event threads. + ''' + self.disconnect() + self.run = False + self.scheduler.run = False + def incoming_filter(self, xmlobj): return xmlobj - + def __spawnEvent(self, xmlobj): "watching xmlOut and processes handlers" #convert XML into Stanza + # TODO surround this log statement with an if, it's expensive logging.debug("RECV: %s" % cElementTree.tostring(xmlobj)) xmlobj = self.incoming_filter(xmlobj) stanza = None @@ -305,48 +352,54 @@ class XMLStream(object): if stanza is None: stanza = StanzaBase(self, xmlobj) unhandled = True + # TODO inefficient linear search; performance might be improved by hashtable lookup for handler in self.__handlers: if handler.match(stanza): + logging.debug('matched stanza to handler %s', handler.name) handler.prerun(stanza) self.eventqueue.put(('stanza', handler, stanza)) - if handler.checkDelete(): self.__handlers.pop(self.__handlers.index(handler)) + if handler.checkDelete(): + logging.debug('deleting callback %s', handler.name) + self.__handlers.pop(self.__handlers.index(handler)) unhandled = False if unhandled: stanza.unhandled() #loop through handlers and test match #spawn threads as necessary, call handlers, sending Stanza - + def _eventRunner(self): logging.debug("Loading event runner") while self.run: try: event = self.eventqueue.get(True, timeout=5) except queue.Empty: +# logging.debug('Nothing on event queue') event = None if event is not None: etype = event[0] handler = event[1] args = event[2:] - #etype, handler, *args = event #python 3.x way + #etype, handler, *args = event #python 3.x way if etype == 'stanza': try: handler.run(args[0]) except Exception as e: - traceback.print_exc() + logging.exception("Exception in event handler") args[0].exception(e) elif etype == 'sched': try: + #handler(*args[0]) handler.run(*args) except: logging.error(traceback.format_exc()) elif etype == 'quit': logging.debug("Quitting eventRunner thread") return False - + def registerHandler(self, handler, before=None, after=None): "Add handler with matcher class and parameters." self.__handlers.append(handler) - + def removeHandler(self, name): "Removes the handler." idx = 0 @@ -432,4 +485,4 @@ class XMLStream(object): def start_stream_handler(self, xml): """Meant to be overridden""" - pass + logging.warn("No start stream handler has been implemented.") |