diff options
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r-- | sleekxmpp/xmlstream/cert.py | 4 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/filesocket.py | 9 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/xmlmask.py | 71 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/xpath.py | 37 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/resolver.py | 7 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/scheduler.py | 34 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/stanzabase.py | 54 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/tostring.py | 35 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 67 |
9 files changed, 158 insertions, 160 deletions
diff --git a/sleekxmpp/xmlstream/cert.py b/sleekxmpp/xmlstream/cert.py index fa12f794..71146f36 100644 --- a/sleekxmpp/xmlstream/cert.py +++ b/sleekxmpp/xmlstream/cert.py @@ -1,6 +1,10 @@ import logging from datetime import datetime, timedelta +# Make a call to strptime before starting threads to +# prevent thread safety issues. +datetime.strptime('1970-01-01 12:00:00', "%Y-%m-%d %H:%M:%S") + try: from pyasn1.codec.der import decoder, encoder diff --git a/sleekxmpp/xmlstream/filesocket.py b/sleekxmpp/xmlstream/filesocket.py index d4537998..53b83bc7 100644 --- a/sleekxmpp/xmlstream/filesocket.py +++ b/sleekxmpp/xmlstream/filesocket.py @@ -13,6 +13,7 @@ """ from socket import _fileobject +import errno import socket @@ -29,7 +30,13 @@ class FileSocket(_fileobject): """Read data from the socket as if it were a file.""" if self._sock is None: return None - data = self._sock.recv(size) + while True: + try: + data = self._sock.recv(size) + break + except socket.error as serr: + if serr.errno != errno.EINTR: + raise if data is not None: return data diff --git a/sleekxmpp/xmlstream/matcher/xmlmask.py b/sleekxmpp/xmlstream/matcher/xmlmask.py index a0568f08..56f728e1 100644 --- a/sleekxmpp/xmlstream/matcher/xmlmask.py +++ b/sleekxmpp/xmlstream/matcher/xmlmask.py @@ -14,12 +14,6 @@ from sleekxmpp.xmlstream.stanzabase import ET from sleekxmpp.xmlstream.matcher.base import MatcherBase -# Flag indicating if the builtin XPath matcher should be used, which -# uses namespaces, or a custom matcher that ignores namespaces. -# Changing this will affect ALL XMLMask matchers. -IGNORE_NS = False - - log = logging.getLogger(__name__) @@ -39,19 +33,15 @@ class MatchXMLMask(MatcherBase): :class:`~sleekxmpp.xmlstream.matcher.stanzapath.StanzaPath` should be used instead. - The use of namespaces in the mask comparison is controlled by - ``IGNORE_NS``. Setting ``IGNORE_NS`` to ``True`` will disable namespace - based matching for ALL XMLMask matchers. - :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML object or XML string to use as a mask. """ - def __init__(self, criteria): + def __init__(self, criteria, default_ns='jabber:client'): MatcherBase.__init__(self, criteria) if isinstance(criteria, str): self._criteria = ET.fromstring(self._criteria) - self.default_ns = 'jabber:client' + self.default_ns = default_ns def setDefaultNS(self, ns): """Set the default namespace to use during comparisons. @@ -84,8 +74,6 @@ class MatchXMLMask(MatcherBase): do not have a specified namespace. Defaults to ``"__no_ns__"``. """ - use_ns = not IGNORE_NS - if source is None: # If the element was not found. May happend during recursive calls. return False @@ -96,17 +84,10 @@ class MatchXMLMask(MatcherBase): mask = ET.fromstring(mask) except ExpatError: log.warning("Expat error: %s\nIn parsing: %s", '', mask) - if not use_ns: - # Compare the element without using namespaces. - source_tag = source.tag.split('}', 1)[-1] - mask_tag = mask.tag.split('}', 1)[-1] - if source_tag != mask_tag: - return False - else: - # Compare the element using namespaces - mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) - if source.tag not in [mask.tag, mask_ns_tag]: - return False + + mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) + if source.tag not in [mask.tag, mask_ns_tag]: + return False # If the mask includes text, compare it. if mask.text and source.text and \ @@ -122,37 +103,15 @@ class MatchXMLMask(MatcherBase): # Recursively check subelements. matched_elements = {} for subelement in mask: - if use_ns: - matched = False - for other in source.findall(subelement.tag): - matched_elements[other] = False - if self._mask_cmp(other, subelement, use_ns): - if not matched_elements.get(other, False): - matched_elements[other] = True - matched = True - if not matched: - return False - else: - if not self._mask_cmp(self._get_child(source, subelement.tag), - subelement, use_ns): - return False + matched = False + for other in source.findall(subelement.tag): + matched_elements[other] = False + if self._mask_cmp(other, subelement, use_ns): + if not matched_elements.get(other, False): + matched_elements[other] = True + matched = True + if not matched: + return False # Everything matches. return True - - def _get_child(self, xml, tag): - """Return a child element given its tag, ignoring namespace values. - - Returns ``None`` if the child was not found. - - :param xml: The :class:`~xml.etree.ElementTree.Element` XML object - to search for the given child tag. - :param tag: The name of the subelement to find. - """ - tag = tag.split('}')[-1] - try: - children = [c.tag.split('}')[-1] for c in xml] - index = children.index(tag) - except ValueError: - return None - return list(xml)[index] diff --git a/sleekxmpp/xmlstream/matcher/xpath.py b/sleekxmpp/xmlstream/matcher/xpath.py index 3f03e68e..f3d28429 100644 --- a/sleekxmpp/xmlstream/matcher/xpath.py +++ b/sleekxmpp/xmlstream/matcher/xpath.py @@ -9,16 +9,10 @@ :license: MIT, see LICENSE for more details """ -from sleekxmpp.xmlstream.stanzabase import ET +from sleekxmpp.xmlstream.stanzabase import ET, fix_ns from sleekxmpp.xmlstream.matcher.base import MatcherBase -# Flag indicating if the builtin XPath matcher should be used, which -# uses namespaces, or a custom matcher that ignores namespaces. -# Changing this will affect ALL XPath matchers. -IGNORE_NS = False - - class MatchXPath(MatcherBase): """ @@ -38,6 +32,9 @@ class MatchXPath(MatcherBase): expressions will be matched without using namespaces. """ + def __init__(self, criteria): + self._criteria = fix_ns(criteria) + def match(self, xml): """ Compare a stanza's XML contents to an XPath expression. @@ -59,28 +56,4 @@ class MatchXPath(MatcherBase): x = ET.Element('x') x.append(xml) - if not IGNORE_NS: - # Use builtin, namespace respecting, XPath matcher. - if x.find(self._criteria) is not None: - return True - return False - else: - # Remove namespaces from the XPath expression. - criteria = [] - for ns_block in self._criteria.split('{'): - criteria.extend(ns_block.split('}')[-1].split('/')) - - # Walk the XPath expression. - xml = x - for tag in criteria: - if not tag: - # Skip empty tag name artifacts from the cleanup phase. - continue - - children = [c.tag.split('}')[-1] for c in xml] - try: - index = children.index(tag) - except ValueError: - return False - xml = list(xml)[index] - return True + return x.find(self._criteria) is not None diff --git a/sleekxmpp/xmlstream/resolver.py b/sleekxmpp/xmlstream/resolver.py index 394daa64..6f26797f 100644 --- a/sleekxmpp/xmlstream/resolver.py +++ b/sleekxmpp/xmlstream/resolver.py @@ -113,7 +113,7 @@ def resolve(host, port=None, service=None, proto='tcp', if hasattr(socket, 'inet_pton'): ipv6 = socket.inet_pton(socket.AF_INET6, host) yield (host, host, port) - except socket.error: + except (socket.error, ValueError): pass # If no service was provided, then we can just do A/AAAA lookups on the @@ -202,11 +202,14 @@ def get_AAAA(host, resolver=None): # If not using dnspython, attempt lookup using the OS level # getaddrinfo() method. if resolver is None: + if not socket.has_ipv6: + log.debug("Unable to query %s for AAAA records: IPv6 is not supported", host) + return [] try: recs = socket.getaddrinfo(host, None, socket.AF_INET6, socket.SOCK_STREAM) return [rec[4][0] for rec in recs] - except socket.gaierror: + except (OSError, socket.gaierror): log.debug("DNS: Error retreiving AAAA address " + \ "info for %s." % host) return [] diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py index b3e50983..e6fae37a 100644 --- a/sleekxmpp/xmlstream/scheduler.py +++ b/sleekxmpp/xmlstream/scheduler.py @@ -20,6 +20,11 @@ import itertools from sleekxmpp.util import Queue, QueueEmpty +#: The time in seconds to wait for events from the event queue, and also the +#: time between checks for the process stop signal. +WAIT_TIMEOUT = 1.0 + + log = logging.getLogger(__name__) @@ -76,7 +81,7 @@ class Task(object): """ if self.qpointer is not None: self.qpointer.put(('schedule', self.callback, - self.args, self.name)) + self.args, self.kwargs, self.name)) else: self.callback(*self.args, **self.kwargs) self.reset() @@ -120,6 +125,10 @@ class Scheduler(object): #: Lock for accessing the task queue. self.schedule_lock = threading.RLock() + #: The time in seconds to wait for events from the event queue, + #: and also the time between checks for the process stop signal. + self.wait_timeout = WAIT_TIMEOUT + def process(self, threaded=True, daemon=False): """Begin accepting and processing scheduled tasks. @@ -139,24 +148,25 @@ class Scheduler(object): self.run = True try: while self.run and not self.stop.is_set(): - wait = 0.1 updated = False if self.schedule: wait = self.schedule[0].next - time.time() + else: + wait = self.wait_timeout try: if wait <= 0.0: newtask = self.addq.get(False) else: - if wait >= 3.0: - wait = 3.0 newtask = None - elapsed = 0 - while not self.stop.is_set() and \ + while self.run and \ + not self.stop.is_set() and \ newtask is None and \ - elapsed < wait: - newtask = self.addq.get(True, 0.1) - elapsed += 0.1 - except QueueEmpty: + wait > 0: + try: + newtask = self.addq.get(True, min(wait, self.wait_timeout)) + except QueueEmpty: # Nothing to add, nothing to do. Check run flags and continue waiting. + wait -= self.wait_timeout + except QueueEmpty: # Time to run some tasks, and no new tasks to add. self.schedule_lock.acquire() # select only those tasks which are to be executed now relevant = itertools.takewhile( @@ -174,11 +184,11 @@ class Scheduler(object): # only need to resort tasks if a repeated task has # been kept in the list. updated = True - else: - updated = True + else: # Add new task self.schedule_lock.acquire() if newtask is not None: self.schedule.append(newtask) + updated = True finally: if updated: self.schedule.sort(key=lambda task: task.next) diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py index 122d7eb4..97107098 100644 --- a/sleekxmpp/xmlstream/stanzabase.py +++ b/sleekxmpp/xmlstream/stanzabase.py @@ -3,7 +3,7 @@ sleekxmpp.xmlstream.stanzabase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - This module implements a wrapper layer for XML objects + module implements a wrapper layer for XML objects that allows them to be treated like dictionaries. Part of SleekXMPP: The Sleek XMPP Library @@ -141,7 +141,7 @@ def multifactory(stanza, plugin_attrib): parent.loaded_plugins.remove(plugin_attrib) try: parent.xml.remove(self.xml) - except: + except ValueError: pass else: for stanza in list(res): @@ -192,7 +192,7 @@ def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''): for element in elements: if element: # Skip empty entry artifacts from splitting. - if propagate_ns: + if propagate_ns and element[0] != '*': tag = '{%s}%s' % (namespace, element) else: tag = element @@ -596,31 +596,39 @@ class ElementBase(object): iterable_interfaces = [p.plugin_attrib for \ p in self.plugin_iterables] + if 'lang' in values: + self['lang'] = values['lang'] + + if 'substanzas' in values: + # Remove existing substanzas + for stanza in self.iterables: + try: + self.xml.remove(stanza.xml) + except ValueError: + pass + self.iterables = [] + + # Add new substanzas + for subdict in values['substanzas']: + if '__childtag__' in subdict: + for subclass in self.plugin_iterables: + child_tag = "{%s}%s" % (subclass.namespace, + subclass.name) + if subdict['__childtag__'] == child_tag: + sub = subclass(parent=self) + sub.values = subdict + self.iterables.append(sub) + for interface, value in values.items(): full_interface = interface interface_lang = ('%s|' % interface).split('|') interface = interface_lang[0] lang = interface_lang[1] or self.get_lang() - if interface == 'substanzas': - # Remove existing substanzas - for stanza in self.iterables: - self.xml.remove(stanza.xml) - self.iterables = [] - - # Add new substanzas - for subdict in value: - if '__childtag__' in subdict: - for subclass in self.plugin_iterables: - child_tag = "{%s}%s" % (subclass.namespace, - subclass.name) - if subdict['__childtag__'] == child_tag: - sub = subclass(parent=self) - sub.values = subdict - self.iterables.append(sub) - break - elif interface == 'lang': - self[interface] = value + if interface == 'lang': + continue + elif interface == 'substanzas': + continue elif interface in self.interfaces: self[full_interface] = value elif interface in self.plugin_attrib_map: @@ -866,7 +874,7 @@ class ElementBase(object): self.loaded_plugins.remove(attrib) try: self.xml.remove(plugin.xml) - except: + except ValueError: pass return self diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py index 08d7ad02..c49abd3e 100644 --- a/sleekxmpp/xmlstream/tostring.py +++ b/sleekxmpp/xmlstream/tostring.py @@ -24,8 +24,8 @@ if sys.version_info < (3, 0): XML_NS = 'http://www.w3.org/XML/1998/namespace' -def tostring(xml=None, xmlns='', stream=None, - outbuffer='', top_level=False, open_only=False): +def tostring(xml=None, xmlns='', stream=None, outbuffer='', + top_level=False, open_only=False, namespaces=None): """Serialize an XML object to a Unicode string. If an outer xmlns is provided using ``xmlns``, then the current element's @@ -41,7 +41,8 @@ def tostring(xml=None, xmlns='', stream=None, during recursive calls. :param bool top_level: Indicates that the element is the outermost element. - + :param set namespaces: Track which namespaces are in active use so + that new ones can be declared when needed. :type xml: :py:class:`~xml.etree.ElementTree.Element` :type stream: :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream` @@ -63,6 +64,7 @@ def tostring(xml=None, xmlns='', stream=None, default_ns = '' stream_ns = '' use_cdata = False + if stream: default_ns = stream.default_ns stream_ns = stream.stream_ns @@ -82,6 +84,7 @@ def tostring(xml=None, xmlns='', stream=None, output.append(namespace) # Output escaped attribute values. + new_namespaces = set() for attrib, value in xml.attrib.items(): value = escape(value, use_cdata) if '}' not in attrib: @@ -89,14 +92,20 @@ def tostring(xml=None, xmlns='', stream=None, else: attrib_ns = attrib.split('}')[0][1:] attrib = attrib.split('}')[1] - if stream and attrib_ns in stream.namespace_map: + if attrib_ns == XML_NS: + output.append(' xml:%s="%s"' % (attrib, value)) + elif stream and attrib_ns in stream.namespace_map: mapped_ns = stream.namespace_map[attrib_ns] if mapped_ns: - output.append(' %s:%s="%s"' % (mapped_ns, - attrib, - value)) - elif attrib_ns == XML_NS: - output.append(' xml:%s="%s"' % (attrib, value)) + if namespaces is None: + namespaces = set() + if attrib_ns not in namespaces: + namespaces.add(attrib_ns) + new_namespaces.add(attrib_ns) + output.append(' xmlns:%s="%s"' % ( + mapped_ns, attrib_ns)) + output.append(' %s:%s="%s"' % ( + mapped_ns, attrib, value)) if open_only: # Only output the opening tag, regardless of content. @@ -110,7 +119,8 @@ def tostring(xml=None, xmlns='', stream=None, output.append(escape(xml.text, use_cdata)) if len(xml): for child in xml: - output.append(tostring(child, tag_xmlns, stream)) + output.append(tostring(child, tag_xmlns, stream, + namespaces=namespaces)) output.append("</%s>" % tag_name) elif xml.text: # If we only have text content. @@ -121,6 +131,11 @@ def tostring(xml=None, xmlns='', stream=None, if xml.tail: # If there is additional text after the element. output.append(escape(xml.tail, use_cdata)) + for ns in new_namespaces: + # Remove namespaces introduced in this context. This is necessary + # because the namespaces object continues to be shared with other + # contexts. + namespaces.remove(ns) return ''.join(output) diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index bea6e88f..8242a127 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -26,6 +26,7 @@ import time import random import weakref import uuid +import errno from xml.parsers.expat import ExpatError @@ -49,7 +50,7 @@ RESPONSE_TIMEOUT = 30 #: The time in seconds to wait for events from the event queue, and also the #: time between checks for the process stop signal. -WAIT_TIMEOUT = 0.1 +WAIT_TIMEOUT = 1.0 #: The number of threads to use to handle XML stream events. This is not the #: same as the number of custom event handling threads. @@ -461,10 +462,10 @@ class XMLStream(object): time.sleep(0.1) elapsed += 0.1 except KeyboardInterrupt: - self.stop.set() + self.set_stop() return False except SystemExit: - self.stop.set() + self.set_stop() return False if self.default_domain: @@ -550,7 +551,7 @@ class XMLStream(object): cert.verify(self._expected_server_name, self._der_cert) except cert.CertificateError as err: if not self.event_handled('ssl_invalid_cert'): - log.error(err.message) + log.error(err) self.disconnect(send_close=False) else: self.event('ssl_invalid_cert', @@ -559,8 +560,7 @@ class XMLStream(object): self.set_socket(self.socket, ignore=True) #this event is where you should set your application state - self.event("connected", direct=True) - self.reconnect_delay = 1.0 + self.event('connected', direct=True) return True except (Socket.error, ssl.SSLError) as serr: error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" @@ -611,6 +611,7 @@ class XMLStream(object): lines = resp.split('\r\n') if '200' not in lines[0]: self.event('proxy_error', resp) + self.event('connection_failed', direct=True) log.error('Proxy Error: %s', lines[0]) return False @@ -706,7 +707,7 @@ class XMLStream(object): self.stream_end_event.set() if not self.auto_reconnect: - self.stop.set() + self.set_stop() if self._disconnect_wait_for_threads: self._wait_for_threads() @@ -718,12 +719,12 @@ class XMLStream(object): self.event('socket_error', serr, direct=True) finally: #clear your application state - self.event("disconnected", direct=True) + self.event('disconnected', direct=True) return True def abort(self): self.session_started_event.clear() - self.stop.set() + self.set_stop() if self._disconnect_wait_for_threads: self._wait_for_threads() try: @@ -859,7 +860,7 @@ class XMLStream(object): cert.verify(self._expected_server_name, self._der_cert) except cert.CertificateError as err: if not self.event_handled('ssl_invalid_cert'): - log.error(err.message) + log.error(err) self.disconnect(self.auto_reconnect, send_close=False) else: self.event('ssl_invalid_cert', pem_cert, direct=True) @@ -1016,9 +1017,13 @@ class XMLStream(object): # and handler classes here. if name is None: - name = 'add_handler_%s' % self.getNewId() - self.registerHandler(XMLCallback(name, MatchXMLMask(mask), pointer, - once=disposable, instream=instream)) + name = 'add_handler_%s' % self.new_id() + self.register_handler( + XMLCallback(name, + MatchXMLMask(mask, self.default_ns), + pointer, + once=disposable, + instream=instream)) def register_handler(self, handler, before=None, after=None): """Add a stream event handler that will be executed when a matching @@ -1131,6 +1136,8 @@ class XMLStream(object): event queue. All event handlers will run in the same thread. """ + log.debug("Event triggered: " + name) + handlers = self.__event_handlers.get(name, []) for handler in handlers: #TODO: Data should not be copied, but should be read only, @@ -1288,6 +1295,9 @@ class XMLStream(object): try: sent += self.socket.send(data[sent:]) count += 1 + except Socket.error as serr: + if serr.errno != errno.EINTR: + raise except ssl.SSLError as serr: if tries >= self.ssl_retry_max: log.debug('SSL error: max retries reached') @@ -1350,6 +1360,13 @@ class XMLStream(object): if self.__thread_count == 0: self.__thread_cond.notify() + def set_stop(self): + self.stop.set() + + # Unlock queues + self.event_queue.put(None) + self.send_queue.put(None) + def _wait_for_threads(self): with self.__thread_cond: if self.__thread_count != 0: @@ -1493,6 +1510,10 @@ class XMLStream(object): # as handshakes. self.stream_end_event.clear() self.start_stream_handler(root) + + # We have a successful stream connection, so reset + # exponential backoff for new reconnect attempts. + self.reconnect_delay = 1.0 depth += 1 if event == b'end': depth -= 1 @@ -1618,11 +1639,7 @@ class XMLStream(object): log.debug("Loading event runner") try: while not self.stop.is_set(): - try: - wait = self.wait_timeout - event = self.event_queue.get(True, timeout=wait) - except QueueEmpty: - event = None + event = self.event_queue.get() if event is None: continue @@ -1638,10 +1655,10 @@ class XMLStream(object): log.exception(error_msg, handler.name) orig.exception(e) elif etype == 'schedule': - name = args[1] + name = args[2] try: log.debug('Scheduled event: %s: %s', name, args[0]) - handler(*args[0]) + handler(*args[0], **args[1]) except Exception as e: log.exception('Error processing scheduled task') self.exception(e) @@ -1683,14 +1700,13 @@ class XMLStream(object): while not self.stop.is_set(): while not self.stop.is_set() and \ not self.session_started_event.is_set(): - self.session_started_event.wait(timeout=0.1) + self.session_started_event.wait(timeout=0.1) # Wait for session start if self.__failed_send_stanza is not None: data = self.__failed_send_stanza self.__failed_send_stanza = None else: - try: - data = self.send_queue.get(True, 1) - except QueueEmpty: + data = self.send_queue.get() # Wait for data to send + if data is None: continue log.debug("SEND: %s", data) enc_data = data.encode('utf-8') @@ -1705,6 +1721,9 @@ class XMLStream(object): try: sent += self.socket.send(enc_data[sent:]) count += 1 + except Socket.error as serr: + if serr.errno != errno.EINTR: + raise except ssl.SSLError as serr: if tries >= self.ssl_retry_max: log.debug('SSL error: max retries reached') |