diff options
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r-- | sleekxmpp/xmlstream/__init__.py | 2 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/cert.py | 18 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/filesocket.py | 11 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/handler/__init__.py | 1 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/handler/collector.py | 66 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/handler/waiter.py | 9 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/jid.py | 146 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/__init__.py | 1 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/idsender.py | 47 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/xmlmask.py | 71 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/xpath.py | 37 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/resolver.py | 79 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/scheduler.py | 70 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/stanzabase.py | 114 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/tostring.py | 83 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 295 |
16 files changed, 571 insertions, 479 deletions
diff --git a/sleekxmpp/xmlstream/__init__.py b/sleekxmpp/xmlstream/__init__.py index 67b20c56..5a1ea1be 100644 --- a/sleekxmpp/xmlstream/__init__.py +++ b/sleekxmpp/xmlstream/__init__.py @@ -6,7 +6,7 @@ See the file LICENSE for copying permission. """ -from sleekxmpp.xmlstream.jid import JID +from sleekxmpp.jid import JID from sleekxmpp.xmlstream.scheduler import Scheduler from sleekxmpp.xmlstream.stanzabase import StanzaBase, ElementBase, ET from sleekxmpp.xmlstream.stanzabase import register_stanza_plugin diff --git a/sleekxmpp/xmlstream/cert.py b/sleekxmpp/xmlstream/cert.py index 339f872d..71146f36 100644 --- a/sleekxmpp/xmlstream/cert.py +++ b/sleekxmpp/xmlstream/cert.py @@ -1,6 +1,10 @@ import logging from datetime import datetime, timedelta +# Make a call to strptime before starting threads to +# prevent thread safety issues. +datetime.strptime('1970-01-01 12:00:00', "%Y-%m-%d %H:%M:%S") + try: from pyasn1.codec.der import decoder, encoder @@ -94,7 +98,7 @@ def extract_names(raw_cert): def extract_dates(raw_cert): if not HAVE_PYASN1: - log.warning("Could not find pyasn1 module. " + \ + log.warning("Could not find pyasn1 and pyasn1_modules. " + \ "SSL certificate expiration COULD NOT BE VERIFIED.") return None, None @@ -130,7 +134,7 @@ def get_ttl(raw_cert): def verify(expected, raw_cert): if not HAVE_PYASN1: - log.warning("Could not find pyasn1 module. " + \ + log.warning("Could not find pyasn1 and pyasn1_modules. " + \ "SSL certificate COULD NOT BE VERIFIED.") return @@ -147,7 +151,10 @@ def verify(expected, raw_cert): raise CertificateError( 'Certificate has expired.') - expected_wild = expected[expected.index('.'):] + if '.' in expected: + expected_wild = expected[expected.index('.'):] + else: + expected_wild = expected expected_srv = '_xmpp-client.%s' % expected for name in cert_names['XMPPAddr']: @@ -160,7 +167,10 @@ def verify(expected, raw_cert): if name == expected: return True if name.startswith('*'): - name_wild = name[name.index('.'):] + if '.' in name: + name_wild = name[name.index('.'):] + else: + name_wild = name if expected_wild == name_wild: return True for name in cert_names['URI']: diff --git a/sleekxmpp/xmlstream/filesocket.py b/sleekxmpp/xmlstream/filesocket.py index 56554c73..53b83bc7 100644 --- a/sleekxmpp/xmlstream/filesocket.py +++ b/sleekxmpp/xmlstream/filesocket.py @@ -13,6 +13,7 @@ """ from socket import _fileobject +import errno import socket @@ -29,12 +30,18 @@ class FileSocket(_fileobject): """Read data from the socket as if it were a file.""" if self._sock is None: return None - data = self._sock.recv(size) + while True: + try: + data = self._sock.recv(size) + break + except socket.error as serr: + if serr.errno != errno.EINTR: + raise if data is not None: return data -class Socket26(socket._socketobject): +class Socket26(socket.socket): """A custom socket implementation that uses our own FileSocket class to work around issues in Python 2.6 when using sockets as files. diff --git a/sleekxmpp/xmlstream/handler/__init__.py b/sleekxmpp/xmlstream/handler/__init__.py index 7bcf0b71..83c87f01 100644 --- a/sleekxmpp/xmlstream/handler/__init__.py +++ b/sleekxmpp/xmlstream/handler/__init__.py @@ -7,6 +7,7 @@ """ from sleekxmpp.xmlstream.handler.callback import Callback +from sleekxmpp.xmlstream.handler.collector import Collector from sleekxmpp.xmlstream.handler.waiter import Waiter from sleekxmpp.xmlstream.handler.xmlcallback import XMLCallback from sleekxmpp.xmlstream.handler.xmlwaiter import XMLWaiter diff --git a/sleekxmpp/xmlstream/handler/collector.py b/sleekxmpp/xmlstream/handler/collector.py new file mode 100644 index 00000000..8f02f8c3 --- /dev/null +++ b/sleekxmpp/xmlstream/handler/collector.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" + sleekxmpp.xmlstream.handler.collector + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Part of SleekXMPP: The Sleek XMPP Library + + :copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout + :license: MIT, see LICENSE for more details +""" + +import logging + +from sleekxmpp.util import Queue, QueueEmpty +from sleekxmpp.xmlstream.handler.base import BaseHandler + + +log = logging.getLogger(__name__) + + +class Collector(BaseHandler): + + """ + The Collector handler allows for collecting a set of stanzas + that match a given pattern. Unlike the Waiter handler, a + Collector does not block execution, and will continue to + accumulate matching stanzas until told to stop. + + :param string name: The name of the handler. + :param matcher: A :class:`~sleekxmpp.xmlstream.matcher.base.MatcherBase` + derived object for matching stanza objects. + :param stream: The :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream` + instance this handler should monitor. + """ + + def __init__(self, name, matcher, stream=None): + BaseHandler.__init__(self, name, matcher, stream=stream) + self._payload = Queue() + + def prerun(self, payload): + """Store the matched stanza when received during processing. + + :param payload: The matched + :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase` object. + """ + self._payload.put(payload) + + def run(self, payload): + """Do not process this handler during the main event loop.""" + pass + + def stop(self): + """ + Stop collection of matching stanzas, and return the ones that + have been stored so far. + """ + self._destroy = True + results = [] + try: + while True: + results.append(self._payload.get(False)) + except QueueEmpty: + pass + + self.stream().remove_handler(self.name) + return results diff --git a/sleekxmpp/xmlstream/handler/waiter.py b/sleekxmpp/xmlstream/handler/waiter.py index 899df17c..66e14496 100644 --- a/sleekxmpp/xmlstream/handler/waiter.py +++ b/sleekxmpp/xmlstream/handler/waiter.py @@ -10,11 +10,8 @@ """ import logging -try: - import queue -except ImportError: - import Queue as queue +from sleekxmpp.util import Queue, QueueEmpty from sleekxmpp.xmlstream.handler.base import BaseHandler @@ -37,7 +34,7 @@ class Waiter(BaseHandler): def __init__(self, name, matcher, stream=None): BaseHandler.__init__(self, name, matcher, stream=stream) - self._payload = queue.Queue() + self._payload = Queue() def prerun(self, payload): """Store the matched stanza when received during processing. @@ -74,7 +71,7 @@ class Waiter(BaseHandler): try: stanza = self._payload.get(True, 1) break - except queue.Empty: + except QueueEmpty: elapsed_time += 1 if elapsed_time >= timeout: log.warning("Timed out waiting for %s", self.name) diff --git a/sleekxmpp/xmlstream/jid.py b/sleekxmpp/xmlstream/jid.py index 281bf4ee..2b59db47 100644 --- a/sleekxmpp/xmlstream/jid.py +++ b/sleekxmpp/xmlstream/jid.py @@ -1,145 +1,5 @@ -# -*- coding: utf-8 -*- -""" - sleekxmpp.xmlstream.jid - ~~~~~~~~~~~~~~~~~~~~~~~ +import logging - This module allows for working with Jabber IDs (JIDs) by - providing accessors for the various components of a JID. +logging.warning('Deprecated: sleekxmpp.xmlstream.jid is moving to sleekxmpp.jid') - Part of SleekXMPP: The Sleek XMPP Library - - :copyright: (c) 2011 Nathanael C. Fritz - :license: MIT, see LICENSE for more details -""" - -from __future__ import unicode_literals - - -class JID(object): - - """ - A representation of a Jabber ID, or JID. - - Each JID may have three components: a user, a domain, and an optional - resource. For example: user@domain/resource - - When a resource is not used, the JID is called a bare JID. - The JID is a full JID otherwise. - - **JID Properties:** - :jid: Alias for ``full``. - :full: The value of the full JID. - :bare: The value of the bare JID. - :user: The username portion of the JID. - :domain: The domain name portion of the JID. - :server: Alias for ``domain``. - :resource: The resource portion of the JID. - - :param string jid: A string of the form ``'[user@]domain[/resource]'``. - """ - - def __init__(self, jid): - """Initialize a new JID""" - self.reset(jid) - - def reset(self, jid): - """Start fresh from a new JID string. - - :param string jid: A string of the form ``'[user@]domain[/resource]'``. - """ - if isinstance(jid, JID): - jid = jid.full - self._full = self._jid = jid - self._domain = None - self._resource = None - self._user = None - self._bare = None - - def __getattr__(self, name): - """Handle getting the JID values, using cache if available. - - :param name: One of: user, server, domain, resource, - full, or bare. - """ - if name == 'resource': - if self._resource is None and '/' in self._jid: - self._resource = self._jid.split('/', 1)[-1] - return self._resource or "" - elif name == 'user': - if self._user is None: - if '@' in self._jid: - self._user = self._jid.split('@', 1)[0] - else: - self._user = self._user - return self._user or "" - elif name in ('server', 'domain', 'host'): - if self._domain is None: - self._domain = self._jid.split('@', 1)[-1].split('/', 1)[0] - return self._domain or "" - elif name in ('full', 'jid'): - return self._jid or "" - elif name == 'bare': - if self._bare is None: - self._bare = self._jid.split('/', 1)[0] - return self._bare or "" - - def __setattr__(self, name, value): - """Edit a JID by updating it's individual values, resetting the - generated JID in the end. - - Arguments: - name -- The name of the JID part. One of: user, domain, - server, resource, full, jid, or bare. - value -- The new value for the JID part. - """ - if name in ('resource', 'user', 'domain'): - object.__setattr__(self, "_%s" % name, value) - self.regenerate() - elif name in ('server', 'domain', 'host'): - self.domain = value - elif name in ('full', 'jid'): - self.reset(value) - self.regenerate() - elif name == 'bare': - if '@' in value: - u, d = value.split('@', 1) - object.__setattr__(self, "_user", u) - object.__setattr__(self, "_domain", d) - else: - object.__setattr__(self, "_user", '') - object.__setattr__(self, "_domain", value) - self.regenerate() - else: - object.__setattr__(self, name, value) - - def regenerate(self): - """Generate a new JID based on current values, useful after editing.""" - jid = "" - if self.user: - jid = "%s@" % self.user - jid += self.domain - if self.resource: - jid += "/%s" % self.resource - self.reset(jid) - - def __str__(self): - """Use the full JID as the string value.""" - return self.full - - def __repr__(self): - return self.full - - def __eq__(self, other): - """ - Two JIDs are considered equal if they have the same full JID value. - """ - other = JID(other) - return self.full == other.full - - def __ne__(self, other): - """Two JIDs are considered unequal if they are not equal.""" - return not self == other - - def __hash__(self): - """Hash a JID based on the string version of its full JID.""" - return hash(self.full) +from sleekxmpp.jid import JID diff --git a/sleekxmpp/xmlstream/matcher/__init__.py b/sleekxmpp/xmlstream/matcher/__init__.py index 1038d1bd..aa74c434 100644 --- a/sleekxmpp/xmlstream/matcher/__init__.py +++ b/sleekxmpp/xmlstream/matcher/__init__.py @@ -7,6 +7,7 @@ """ from sleekxmpp.xmlstream.matcher.id import MatcherId +from sleekxmpp.xmlstream.matcher.idsender import MatchIDSender from sleekxmpp.xmlstream.matcher.many import MatchMany from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath from sleekxmpp.xmlstream.matcher.xmlmask import MatchXMLMask diff --git a/sleekxmpp/xmlstream/matcher/idsender.py b/sleekxmpp/xmlstream/matcher/idsender.py new file mode 100644 index 00000000..5c2c1f51 --- /dev/null +++ b/sleekxmpp/xmlstream/matcher/idsender.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +""" + sleekxmpp.xmlstream.matcher.id + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Part of SleekXMPP: The Sleek XMPP Library + + :copyright: (c) 2011 Nathanael C. Fritz + :license: MIT, see LICENSE for more details +""" + +from sleekxmpp.xmlstream.matcher.base import MatcherBase + + +class MatchIDSender(MatcherBase): + + """ + The IDSender matcher selects stanzas that have the same stanza 'id' + interface value as the desired ID, and that the 'from' value is one + of a set of approved entities that can respond to a request. + """ + + def match(self, xml): + """Compare the given stanza's ``'id'`` attribute to the stored + ``id`` value, and verify the sender's JID. + + :param xml: The :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase` + stanza to compare against. + """ + + selfjid = self._criteria['self'] + peerjid = self._criteria['peer'] + + allowed = {} + allowed[''] = True + allowed[selfjid.bare] = True + allowed[selfjid.host] = True + allowed[peerjid.full] = True + allowed[peerjid.bare] = True + allowed[peerjid.host] = True + + _from = xml['from'] + + try: + return xml['id'] == self._criteria['id'] and allowed[_from] + except KeyError: + return False diff --git a/sleekxmpp/xmlstream/matcher/xmlmask.py b/sleekxmpp/xmlstream/matcher/xmlmask.py index a0568f08..56f728e1 100644 --- a/sleekxmpp/xmlstream/matcher/xmlmask.py +++ b/sleekxmpp/xmlstream/matcher/xmlmask.py @@ -14,12 +14,6 @@ from sleekxmpp.xmlstream.stanzabase import ET from sleekxmpp.xmlstream.matcher.base import MatcherBase -# Flag indicating if the builtin XPath matcher should be used, which -# uses namespaces, or a custom matcher that ignores namespaces. -# Changing this will affect ALL XMLMask matchers. -IGNORE_NS = False - - log = logging.getLogger(__name__) @@ -39,19 +33,15 @@ class MatchXMLMask(MatcherBase): :class:`~sleekxmpp.xmlstream.matcher.stanzapath.StanzaPath` should be used instead. - The use of namespaces in the mask comparison is controlled by - ``IGNORE_NS``. Setting ``IGNORE_NS`` to ``True`` will disable namespace - based matching for ALL XMLMask matchers. - :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML object or XML string to use as a mask. """ - def __init__(self, criteria): + def __init__(self, criteria, default_ns='jabber:client'): MatcherBase.__init__(self, criteria) if isinstance(criteria, str): self._criteria = ET.fromstring(self._criteria) - self.default_ns = 'jabber:client' + self.default_ns = default_ns def setDefaultNS(self, ns): """Set the default namespace to use during comparisons. @@ -84,8 +74,6 @@ class MatchXMLMask(MatcherBase): do not have a specified namespace. Defaults to ``"__no_ns__"``. """ - use_ns = not IGNORE_NS - if source is None: # If the element was not found. May happend during recursive calls. return False @@ -96,17 +84,10 @@ class MatchXMLMask(MatcherBase): mask = ET.fromstring(mask) except ExpatError: log.warning("Expat error: %s\nIn parsing: %s", '', mask) - if not use_ns: - # Compare the element without using namespaces. - source_tag = source.tag.split('}', 1)[-1] - mask_tag = mask.tag.split('}', 1)[-1] - if source_tag != mask_tag: - return False - else: - # Compare the element using namespaces - mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) - if source.tag not in [mask.tag, mask_ns_tag]: - return False + + mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) + if source.tag not in [mask.tag, mask_ns_tag]: + return False # If the mask includes text, compare it. if mask.text and source.text and \ @@ -122,37 +103,15 @@ class MatchXMLMask(MatcherBase): # Recursively check subelements. matched_elements = {} for subelement in mask: - if use_ns: - matched = False - for other in source.findall(subelement.tag): - matched_elements[other] = False - if self._mask_cmp(other, subelement, use_ns): - if not matched_elements.get(other, False): - matched_elements[other] = True - matched = True - if not matched: - return False - else: - if not self._mask_cmp(self._get_child(source, subelement.tag), - subelement, use_ns): - return False + matched = False + for other in source.findall(subelement.tag): + matched_elements[other] = False + if self._mask_cmp(other, subelement, use_ns): + if not matched_elements.get(other, False): + matched_elements[other] = True + matched = True + if not matched: + return False # Everything matches. return True - - def _get_child(self, xml, tag): - """Return a child element given its tag, ignoring namespace values. - - Returns ``None`` if the child was not found. - - :param xml: The :class:`~xml.etree.ElementTree.Element` XML object - to search for the given child tag. - :param tag: The name of the subelement to find. - """ - tag = tag.split('}')[-1] - try: - children = [c.tag.split('}')[-1] for c in xml] - index = children.index(tag) - except ValueError: - return None - return list(xml)[index] diff --git a/sleekxmpp/xmlstream/matcher/xpath.py b/sleekxmpp/xmlstream/matcher/xpath.py index 3f03e68e..f3d28429 100644 --- a/sleekxmpp/xmlstream/matcher/xpath.py +++ b/sleekxmpp/xmlstream/matcher/xpath.py @@ -9,16 +9,10 @@ :license: MIT, see LICENSE for more details """ -from sleekxmpp.xmlstream.stanzabase import ET +from sleekxmpp.xmlstream.stanzabase import ET, fix_ns from sleekxmpp.xmlstream.matcher.base import MatcherBase -# Flag indicating if the builtin XPath matcher should be used, which -# uses namespaces, or a custom matcher that ignores namespaces. -# Changing this will affect ALL XPath matchers. -IGNORE_NS = False - - class MatchXPath(MatcherBase): """ @@ -38,6 +32,9 @@ class MatchXPath(MatcherBase): expressions will be matched without using namespaces. """ + def __init__(self, criteria): + self._criteria = fix_ns(criteria) + def match(self, xml): """ Compare a stanza's XML contents to an XPath expression. @@ -59,28 +56,4 @@ class MatchXPath(MatcherBase): x = ET.Element('x') x.append(xml) - if not IGNORE_NS: - # Use builtin, namespace respecting, XPath matcher. - if x.find(self._criteria) is not None: - return True - return False - else: - # Remove namespaces from the XPath expression. - criteria = [] - for ns_block in self._criteria.split('{'): - criteria.extend(ns_block.split('}')[-1].split('/')) - - # Walk the XPath expression. - xml = x - for tag in criteria: - if not tag: - # Skip empty tag name artifacts from the cleanup phase. - continue - - children = [c.tag.split('}')[-1] for c in xml] - try: - index = children.index(tag) - except ValueError: - return False - xml = list(xml)[index] - return True + return x.find(self._criteria) is not None diff --git a/sleekxmpp/xmlstream/resolver.py b/sleekxmpp/xmlstream/resolver.py index 0d7a8c0d..188e5ac7 100644 --- a/sleekxmpp/xmlstream/resolver.py +++ b/sleekxmpp/xmlstream/resolver.py @@ -32,10 +32,10 @@ log = logging.getLogger(__name__) #: cd dnspython #: git checkout python3 #: python3 setup.py install -USE_DNSPYTHON = False +DNSPYTHON_AVAILABLE = False try: import dns.resolver - USE_DNSPYTHON = True + DNSPYTHON_AVAILABLE = True except ImportError as e: log.debug("Could not find dnspython package. " + \ "Not all features will be available") @@ -47,13 +47,13 @@ def default_resolver(): :returns: A :class:`dns.resolver.Resolver` object if dnspython is available. Otherwise, ``None``. """ - if USE_DNSPYTHON: + if DNSPYTHON_AVAILABLE: return dns.resolver.get_default_resolver() return None def resolve(host, port=None, service=None, proto='tcp', - resolver=None, use_ipv6=True): + resolver=None, use_ipv6=True, use_dnspython=True): """Peform DNS resolution for a given hostname. Resolution may perform SRV record lookups if a service and protocol @@ -77,6 +77,9 @@ def resolve(host, port=None, service=None, proto='tcp', :param use_ipv6: Optionally control the use of IPv6 in situations where it is either not available, or performance is degraded. Defaults to ``True``. + :param use_dnspython: Optionally control if dnspython is used to make + the DNS queries instead of the built-in DNS + library. :type host: string :type port: int @@ -84,14 +87,22 @@ def resolve(host, port=None, service=None, proto='tcp', :type proto: string :type resolver: :class:`dns.resolver.Resolver` :type use_ipv6: bool + :type use_dnspython: bool :return: An iterable of IP address, port pairs in the order dictated by SRV priorities and weights, if applicable. """ + + if not use_dnspython: + if DNSPYTHON_AVAILABLE: + log.debug("DNS: Not using dnspython, but dnspython is installed.") + else: + log.debug("DNS: Not using dnspython.") + if not use_ipv6: log.debug("DNS: Use of IPv6 has been disabled.") - if resolver is None and USE_DNSPYTHON: + if resolver is None and DNSPYTHON_AVAILABLE and use_dnspython: resolver = dns.resolver.get_default_resolver() # An IPv6 literal is allowed to be enclosed in square brackets, but @@ -102,7 +113,7 @@ def resolve(host, port=None, service=None, proto='tcp', try: # If `host` is an IPv4 literal, we can return it immediately. ipv4 = socket.inet_aton(host) - yield (host, port) + yield (host, host, port) except socket.error: pass @@ -112,8 +123,8 @@ def resolve(host, port=None, service=None, proto='tcp', # it immediately. if hasattr(socket, 'inet_pton'): ipv6 = socket.inet_pton(socket.AF_INET6, host) - yield (host, port) - except socket.error: + yield (host, host, port) + except (socket.error, ValueError): pass # If no service was provided, then we can just do A/AAAA lookups on the @@ -122,25 +133,29 @@ def resolve(host, port=None, service=None, proto='tcp', if not service: hosts = [(host, port)] else: - hosts = get_SRV(host, port, service, proto, resolver=resolver) + hosts = get_SRV(host, port, service, proto, + resolver=resolver, + use_dnspython=use_dnspython) for host, port in hosts: results = [] if host == 'localhost': if use_ipv6: - results.append(('::1', port)) - results.append(('127.0.0.1', port)) + results.append((host, '::1', port)) + results.append((host, '127.0.0.1', port)) if use_ipv6: - for address in get_AAAA(host, resolver=resolver): - results.append((address, port)) - for address in get_A(host, resolver=resolver): - results.append((address, port)) + for address in get_AAAA(host, resolver=resolver, + use_dnspython=use_dnspython): + results.append((host, address, port)) + for address in get_A(host, resolver=resolver, + use_dnspython=use_dnspython): + results.append((host, address, port)) - for address, port in results: - yield address, port + for host, address, port in results: + yield host, address, port -def get_A(host, resolver=None): +def get_A(host, resolver=None, use_dnspython=True): """Lookup DNS A records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -148,9 +163,13 @@ def get_A(host, resolver=None): :param host: The hostname to resolve for A record IPv4 addresses. :param resolver: Optional DNS resolver object to use for the query. + :param use_dnspython: Optionally control if dnspython is used to make + the DNS queries instead of the built-in DNS + library. :type host: string :type resolver: :class:`dns.resolver.Resolver` or ``None`` + :type use_dnspython: bool :return: A list of IPv4 literals. """ @@ -158,7 +177,7 @@ def get_A(host, resolver=None): # If not using dnspython, attempt lookup using the OS level # getaddrinfo() method. - if resolver is None: + if resolver is None or not use_dnspython: try: recs = socket.getaddrinfo(host, None, socket.AF_INET, socket.SOCK_STREAM) @@ -183,7 +202,7 @@ def get_A(host, resolver=None): return [] -def get_AAAA(host, resolver=None): +def get_AAAA(host, resolver=None, use_dnspython=True): """Lookup DNS AAAA records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -191,9 +210,13 @@ def get_AAAA(host, resolver=None): :param host: The hostname to resolve for AAAA record IPv6 addresses. :param resolver: Optional DNS resolver object to use for the query. + :param use_dnspython: Optionally control if dnspython is used to make + the DNS queries instead of the built-in DNS + library. :type host: string :type resolver: :class:`dns.resolver.Resolver` or ``None`` + :type use_dnspython: bool :return: A list of IPv6 literals. """ @@ -201,12 +224,15 @@ def get_AAAA(host, resolver=None): # If not using dnspython, attempt lookup using the OS level # getaddrinfo() method. - if resolver is None: + if resolver is None or not use_dnspython: + if not socket.has_ipv6: + log.debug("Unable to query %s for AAAA records: IPv6 is not supported", host) + return [] try: recs = socket.getaddrinfo(host, None, socket.AF_INET6, socket.SOCK_STREAM) return [rec[4][0] for rec in recs] - except socket.gaierror: + except (OSError, socket.gaierror): log.debug("DNS: Error retreiving AAAA address " + \ "info for %s." % host) return [] @@ -227,7 +253,7 @@ def get_AAAA(host, resolver=None): return [] -def get_SRV(host, port, service, proto='tcp', resolver=None): +def get_SRV(host, port, service, proto='tcp', resolver=None, use_dnspython=True): """Perform SRV record resolution for a given host. .. note:: @@ -253,7 +279,7 @@ def get_SRV(host, port, service, proto='tcp', resolver=None): :return: A list of hostname, port pairs in the order dictacted by SRV priorities and weights. """ - if resolver is None: + if resolver is None or not use_dnspython: log.warning("DNS: dnspython not found. Can not use SRV lookup.") return [(host, port)] @@ -297,7 +323,10 @@ def get_SRV(host, port, service, proto='tcp', resolver=None): for running_sum in sums: if running_sum >= selected: rec = sums[running_sum] - sorted_recs.append((rec.target.to_text(), rec.port)) + host = rec.target.to_text() + if host.endswith('.'): + host = host[:-1] + sorted_recs.append((host, rec.port)) answers[priority].remove(rec) break diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py index f68af081..e6fae37a 100644 --- a/sleekxmpp/xmlstream/scheduler.py +++ b/sleekxmpp/xmlstream/scheduler.py @@ -15,10 +15,14 @@ import time import threading import logging -try: - import queue -except ImportError: - import Queue as queue +import itertools + +from sleekxmpp.util import Queue, QueueEmpty + + +#: The time in seconds to wait for events from the event queue, and also the +#: time between checks for the process stop signal. +WAIT_TIMEOUT = 1.0 log = logging.getLogger(__name__) @@ -77,7 +81,7 @@ class Task(object): """ if self.qpointer is not None: self.qpointer.put(('schedule', self.callback, - self.args, self.name)) + self.args, self.kwargs, self.name)) else: self.callback(*self.args, **self.kwargs) self.reset() @@ -102,7 +106,7 @@ class Scheduler(object): def __init__(self, parentstop=None): #: A queue for storing tasks - self.addq = queue.Queue() + self.addq = Queue() #: A list of tasks in order of execution time. self.schedule = [] @@ -121,6 +125,10 @@ class Scheduler(object): #: Lock for accessing the task queue. self.schedule_lock = threading.RLock() + #: The time in seconds to wait for events from the event queue, + #: and also the time between checks for the process stop signal. + self.wait_timeout = WAIT_TIMEOUT + def process(self, threaded=True, daemon=False): """Begin accepting and processing scheduled tasks. @@ -140,44 +148,50 @@ class Scheduler(object): self.run = True try: while self.run and not self.stop.is_set(): - wait = 0.1 updated = False if self.schedule: wait = self.schedule[0].next - time.time() + else: + wait = self.wait_timeout try: if wait <= 0.0: newtask = self.addq.get(False) else: - if wait >= 3.0: - wait = 3.0 newtask = None - elapsed = 0 - while not self.stop.is_set() and \ + while self.run and \ + not self.stop.is_set() and \ newtask is None and \ - elapsed < wait: - newtask = self.addq.get(True, 0.1) - elapsed += 0.1 - except queue.Empty: - cleanup = [] + wait > 0: + try: + newtask = self.addq.get(True, min(wait, self.wait_timeout)) + except QueueEmpty: # Nothing to add, nothing to do. Check run flags and continue waiting. + wait -= self.wait_timeout + except QueueEmpty: # Time to run some tasks, and no new tasks to add. self.schedule_lock.acquire() - for task in self.schedule: - if time.time() >= task.next: - updated = True - if not task.run(): - cleanup.append(task) + # select only those tasks which are to be executed now + relevant = itertools.takewhile( + lambda task: time.time() >= task.next, self.schedule) + # run the tasks and keep the return value in a tuple + status = map(lambda task: (task, task.run()), relevant) + # remove non-repeating tasks + for task, doRepeat in status: + if not doRepeat: + try: + self.schedule.remove(task) + except ValueError: + pass else: - break - for task in cleanup: - self.schedule.pop(self.schedule.index(task)) - else: - updated = True + # only need to resort tasks if a repeated task has + # been kept in the list. + updated = True + else: # Add new task self.schedule_lock.acquire() if newtask is not None: self.schedule.append(newtask) + updated = True finally: if updated: - self.schedule = sorted(self.schedule, - key=lambda task: task.next) + self.schedule.sort(key=lambda task: task.next) self.schedule_lock.release() except KeyboardInterrupt: self.run = False diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py index 4af441cc..11c8dd67 100644 --- a/sleekxmpp/xmlstream/stanzabase.py +++ b/sleekxmpp/xmlstream/stanzabase.py @@ -3,7 +3,7 @@ sleekxmpp.xmlstream.stanzabase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - This module implements a wrapper layer for XML objects + module implements a wrapper layer for XML objects that allows them to be treated like dictionaries. Part of SleekXMPP: The Sleek XMPP Library @@ -19,6 +19,7 @@ import logging import weakref from xml.etree import cElementTree as ET +from sleekxmpp.util import safedict from sleekxmpp.xmlstream import JID from sleekxmpp.xmlstream.tostring import tostring from sleekxmpp.thirdparty import OrderedDict @@ -141,7 +142,7 @@ def multifactory(stanza, plugin_attrib): parent.loaded_plugins.remove(plugin_attrib) try: parent.xml.remove(self.xml) - except: + except ValueError: pass else: for stanza in list(res): @@ -192,7 +193,7 @@ def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''): for element in elements: if element: # Skip empty entry artifacts from splitting. - if propagate_ns: + if propagate_ns and element[0] != '*': tag = '{%s}%s' % (namespace, element) else: tag = element @@ -488,7 +489,7 @@ class ElementBase(object): """ return self.init_plugin(attrib, lang) - def _get_plugin(self, name, lang=None): + def _get_plugin(self, name, lang=None, check=False): if lang is None: lang = self.get_lang() @@ -501,12 +502,12 @@ class ElementBase(object): if (name, None) in self.plugins: return self.plugins[(name, None)] else: - return self.init_plugin(name, lang) + return None if check else self.init_plugin(name, lang) else: if (name, lang) in self.plugins: return self.plugins[(name, lang)] else: - return self.init_plugin(name, lang) + return None if check else self.init_plugin(name, lang) def init_plugin(self, attrib, lang=None, existing_xml=None, reuse=True): """Enable and initialize a stanza plugin. @@ -514,8 +515,9 @@ class ElementBase(object): :param string attrib: The :attr:`plugin_attrib` value of the plugin to enable. """ - if lang is None: - lang = self.get_lang() + default_lang = self.get_lang() + if not lang: + lang = default_lang plugin_class = self.plugin_attrib_map[attrib] @@ -524,19 +526,13 @@ class ElementBase(object): if reuse and (attrib, lang) in self.plugins: return self.plugins[(attrib, lang)] - if existing_xml is None: - existing_xml = self.xml.find(plugin_class.tag_name()) - - if existing_xml is not None: - if existing_xml.attrib.get('{%s}lang' % XML_NS, '') != lang: - existing_xml = None - plugin = plugin_class(parent=self, xml=existing_xml) if plugin.is_extension: self.plugins[(attrib, None)] = plugin else: - plugin['lang'] = lang + if lang != default_lang: + plugin['lang'] = lang self.plugins[(attrib, lang)] = plugin if plugin_class in self.plugin_iterables: @@ -570,13 +566,16 @@ class ElementBase(object): values = {} values['lang'] = self['lang'] for interface in self.interfaces: - values[interface] = self[interface] + if isinstance(self[interface], JID): + values[interface] = self[interface].jid + else: + values[interface] = self[interface] if interface in self.lang_interfaces: values['%s|*' % interface] = self['%s|*' % interface] for plugin, stanza in self.plugins.items(): lang = stanza['lang'] if lang: - values['%s|%s' % (plugin, lang)] = stanza.values + values['%s|%s' % (plugin[0], lang)] = stanza.values else: values[plugin[0]] = stanza.values if self.iterables: @@ -601,31 +600,39 @@ class ElementBase(object): iterable_interfaces = [p.plugin_attrib for \ p in self.plugin_iterables] + if 'lang' in values: + self['lang'] = values['lang'] + + if 'substanzas' in values: + # Remove existing substanzas + for stanza in self.iterables: + try: + self.xml.remove(stanza.xml) + except ValueError: + pass + self.iterables = [] + + # Add new substanzas + for subdict in values['substanzas']: + if '__childtag__' in subdict: + for subclass in self.plugin_iterables: + child_tag = "{%s}%s" % (subclass.namespace, + subclass.name) + if subdict['__childtag__'] == child_tag: + sub = subclass(parent=self) + sub.values = subdict + self.iterables.append(sub) + for interface, value in values.items(): full_interface = interface interface_lang = ('%s|' % interface).split('|') interface = interface_lang[0] lang = interface_lang[1] or self.get_lang() - if interface == 'substanzas': - # Remove existing substanzas - for stanza in self.iterables: - self.xml.remove(stanza.xml) - self.iterables = [] - - # Add new substanzas - for subdict in value: - if '__childtag__' in subdict: - for subclass in self.plugin_iterables: - child_tag = "{%s}%s" % (subclass.namespace, - subclass.name) - if subdict['__childtag__'] == child_tag: - sub = subclass(parent=self) - sub.values = subdict - self.iterables.append(sub) - break - elif interface == 'lang': - self[interface] = value + if interface == 'lang': + continue + elif interface == 'substanzas': + continue elif interface in self.interfaces: self[full_interface] = value elif interface in self.plugin_attrib_map: @@ -667,12 +674,14 @@ class ElementBase(object): full_attrib = attrib attrib_lang = ('%s|' % attrib).split('|') attrib = attrib_lang[0] - lang = attrib_lang[1] or '' + lang = attrib_lang[1] or None kwargs = {} if lang and attrib in self.lang_interfaces: kwargs['lang'] = lang + kwargs = safedict(kwargs) + if attrib == 'substanzas': return self.iterables elif attrib in self.interfaces or attrib == 'lang': @@ -743,12 +752,14 @@ class ElementBase(object): full_attrib = attrib attrib_lang = ('%s|' % attrib).split('|') attrib = attrib_lang[0] - lang = attrib_lang[1] or '' + lang = attrib_lang[1] or None kwargs = {} if lang and attrib in self.lang_interfaces: kwargs['lang'] = lang + kwargs = safedict(kwargs) + if attrib in self.interfaces or attrib == 'lang': if value is not None: set_method = "set_%s" % attrib.lower() @@ -829,12 +840,14 @@ class ElementBase(object): full_attrib = attrib attrib_lang = ('%s|' % attrib).split('|') attrib = attrib_lang[0] - lang = attrib_lang[1] or '' + lang = attrib_lang[1] or None kwargs = {} if lang and attrib in self.lang_interfaces: kwargs['lang'] = lang + kwargs = safedict(kwargs) + if attrib in self.interfaces or attrib == 'lang': del_method = "del_%s" % attrib.lower() del_method2 = "del%s" % attrib.title() @@ -860,18 +873,18 @@ class ElementBase(object): else: self._del_attr(attrib) elif attrib in self.plugin_attrib_map: - plugin = self._get_plugin(attrib, lang) + plugin = self._get_plugin(attrib, lang, check=True) if not plugin: return self if plugin.is_extension: del plugin[full_attrib] del self.plugins[(attrib, None)] else: - del self.plugins[(attrib, lang)] + del self.plugins[(attrib, plugin['lang'])] self.loaded_plugins.remove(attrib) try: self.xml.remove(plugin.xml) - except: + except ValueError: pass return self @@ -1222,6 +1235,10 @@ class ElementBase(object): if item.__class__ in self.plugin_iterables: if item.__class__.plugin_multi_attrib: self.init_plugin(item.__class__.plugin_multi_attrib) + elif item.__class__ == self.plugin_tag_map.get(item.tag_name(), None): + self.init_plugin(item.plugin_attrib, + existing_xml=item.xml, + reuse=False) return self def appendxml(self, xml): @@ -1398,10 +1415,8 @@ class ElementBase(object): :param bool top_level_ns: Display the top-most namespace. Defaults to True. """ - stanza_ns = '' if top_level_ns else self.namespace return tostring(self.xml, xmlns='', - stanza_ns=stanza_ns, - top_level=not top_level_ns) + top_level=True) def __repr__(self): """Use the stanza's serialized XML as its representation.""" @@ -1590,11 +1605,10 @@ class StanzaBase(ElementBase): :param bool top_level_ns: Display the top-most namespace. Defaults to ``False``. """ - stanza_ns = '' if top_level_ns else self.namespace - return tostring(self.xml, xmlns='', - stanza_ns=stanza_ns, + xmlns = self.stream.default_ns if self.stream else '' + return tostring(self.xml, xmlns=xmlns, stream=self.stream, - top_level=not top_level_ns) + top_level=(self.stream is None)) #: A JSON/dictionary version of the XML content exposed through diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py index 2480f9b2..c49abd3e 100644 --- a/sleekxmpp/xmlstream/tostring.py +++ b/sleekxmpp/xmlstream/tostring.py @@ -24,25 +24,25 @@ if sys.version_info < (3, 0): XML_NS = 'http://www.w3.org/XML/1998/namespace' -def tostring(xml=None, xmlns='', stanza_ns='', stream=None, - outbuffer='', top_level=False, open_only=False): +def tostring(xml=None, xmlns='', stream=None, outbuffer='', + top_level=False, open_only=False, namespaces=None): """Serialize an XML object to a Unicode string. - If namespaces are provided using ``xmlns`` or ``stanza_ns``, then - elements that use those namespaces will not include the xmlns attribute - in the output. + If an outer xmlns is provided using ``xmlns``, then the current element's + namespace will not be included if it matches the outer namespace. An + exception is made for elements that have an attached stream, and appear + at the stream root. :param XML xml: The XML object to serialize. :param string xmlns: Optional namespace of an element wrapping the XML object. - :param string stanza_ns: The namespace of the stanza object that contains - the XML object. :param stream: The XML stream that generated the XML object. :param string outbuffer: Optional buffer for storing serializations during recursive calls. :param bool top_level: Indicates that the element is the outermost element. - + :param set namespaces: Track which namespaces are in active use so + that new ones can be declared when needed. :type xml: :py:class:`~xml.etree.ElementTree.Element` :type stream: :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream` @@ -63,15 +63,19 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, default_ns = '' stream_ns = '' + use_cdata = False + if stream: default_ns = stream.default_ns stream_ns = stream.stream_ns + use_cdata = stream.use_cdata # Output the tag name and derived namespace of the element. namespace = '' - if top_level and tag_xmlns not in ['', default_ns, stream_ns] or \ - tag_xmlns not in ['', xmlns, stanza_ns, stream_ns]: - namespace = ' xmlns="%s"' % tag_xmlns + if tag_xmlns: + if top_level and tag_xmlns not in [default_ns, xmlns, stream_ns] \ + or not top_level and tag_xmlns != xmlns: + namespace = ' xmlns="%s"' % tag_xmlns if stream and tag_xmlns in stream.namespace_map: mapped_namespace = stream.namespace_map[tag_xmlns] if mapped_namespace: @@ -80,21 +84,28 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, output.append(namespace) # Output escaped attribute values. + new_namespaces = set() for attrib, value in xml.attrib.items(): - value = xml_escape(value) + value = escape(value, use_cdata) if '}' not in attrib: output.append(' %s="%s"' % (attrib, value)) else: attrib_ns = attrib.split('}')[0][1:] attrib = attrib.split('}')[1] - if stream and attrib_ns in stream.namespace_map: + if attrib_ns == XML_NS: + output.append(' xml:%s="%s"' % (attrib, value)) + elif stream and attrib_ns in stream.namespace_map: mapped_ns = stream.namespace_map[attrib_ns] if mapped_ns: - output.append(' %s:%s="%s"' % (mapped_ns, - attrib, - value)) - elif attrib_ns == XML_NS: - output.append(' xml:%s="%s"' % (attrib, value)) + if namespaces is None: + namespaces = set() + if attrib_ns not in namespaces: + namespaces.add(attrib_ns) + new_namespaces.add(attrib_ns) + output.append(' xmlns:%s="%s"' % ( + mapped_ns, attrib_ns)) + output.append(' %s:%s="%s"' % ( + mapped_ns, attrib, value)) if open_only: # Only output the opening tag, regardless of content. @@ -105,24 +116,30 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, # If there are additional child elements to serialize. output.append(">") if xml.text: - output.append(xml_escape(xml.text)) + output.append(escape(xml.text, use_cdata)) if len(xml): for child in xml: - output.append(tostring(child, tag_xmlns, stanza_ns, stream)) + output.append(tostring(child, tag_xmlns, stream, + namespaces=namespaces)) output.append("</%s>" % tag_name) elif xml.text: # If we only have text content. - output.append(">%s</%s>" % (xml_escape(xml.text), tag_name)) + output.append(">%s</%s>" % (escape(xml.text, use_cdata), tag_name)) else: # Empty element. output.append(" />") if xml.tail: # If there is additional text after the element. - output.append(xml_escape(xml.tail)) + output.append(escape(xml.tail, use_cdata)) + for ns in new_namespaces: + # Remove namespaces introduced in this context. This is necessary + # because the namespaces object continues to be shared with other + # contexts. + namespaces.remove(ns) return ''.join(output) -def xml_escape(text): +def escape(text, use_cdata=False): """Convert special characters in XML to escape sequences. :param string text: The XML text to convert. @@ -132,12 +149,24 @@ def xml_escape(text): if type(text) != types.UnicodeType: text = unicode(text, 'utf-8', 'ignore') - text = list(text) escapes = {'&': '&', '<': '<', '>': '>', "'": ''', '"': '"'} - for i, c in enumerate(text): - text[i] = escapes.get(c, c) - return ''.join(text) + + if not use_cdata: + text = list(text) + for i, c in enumerate(text): + text[i] = escapes.get(c, c) + return ''.join(text) + else: + escape_needed = False + for c in text: + if c in escapes: + escape_needed = True + break + if escape_needed: + escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>")) + return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped) + return text diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 49f33933..f9ec4947 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -26,14 +26,12 @@ import time import random import weakref import uuid -try: - import queue -except ImportError: - import Queue as queue +import errno from xml.parsers.expat import ExpatError import sleekxmpp +from sleekxmpp.util import Queue, QueueEmpty, safedict from sleekxmpp.thirdparty.statemachine import StateMachine from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase @@ -52,7 +50,7 @@ RESPONSE_TIMEOUT = 30 #: The time in seconds to wait for events from the event queue, and also the #: time between checks for the process stop signal. -WAIT_TIMEOUT = 0.1 +WAIT_TIMEOUT = 1.0 #: The number of threads to use to handle XML stream events. This is not the #: same as the number of custom event handling threads. @@ -61,9 +59,6 @@ WAIT_TIMEOUT = 0.1 #: a GIL increasing this value can provide better performance. HANDLER_THREADS = 1 -#: Flag indicating if the SSL library is available for use. -SSL_SUPPORT = True - #: The time in seconds to delay between attempts to resend data #: after an SSL error. SSL_RETRY_DELAY = 0.5 @@ -120,9 +115,6 @@ class XMLStream(object): """ def __init__(self, socket=None, host='', port=0): - #: Flag indicating if the SSL library is available for use. - self.ssl_support = SSL_SUPPORT - #: Most XMPP servers support TLSv1, but OpenFire in particular #: does not work well with it. For OpenFire, set #: :attr:`ssl_version` to use ``SSLv23``:: @@ -131,6 +123,11 @@ class XMLStream(object): #: xmpp.ssl_version = ssl.PROTOCOL_SSLv23 self.ssl_version = ssl.PROTOCOL_TLSv1 + #: The list of accepted ciphers, in OpenSSL Format. + #: It might be useful to override it for improved security + #: over the python defaults. + self.ciphers = None + #: Path to a file containing certificates for verifying the #: server SSL certificate. A non-``None`` value will trigger #: certificate checking. @@ -141,6 +138,17 @@ class XMLStream(object): #: be consulted, even if they are not in the provided file. self.ca_certs = None + #: Path to a file containing a client certificate to use for + #: authenticating via SASL EXTERNAL. If set, there must also + #: be a corresponding `:attr:keyfile` value. + self.certfile = None + + #: Path to a file containing the private key for the selected + #: client certificate to use for authenticating via SASL EXTERNAL. + self.keyfile = None + + self._der_cert = None + #: The time in seconds to wait for events from the event queue, #: and also the time between checks for the process stop signal. self.wait_timeout = WAIT_TIMEOUT @@ -184,6 +192,7 @@ class XMLStream(object): #: The expected name of the server, for validation. self._expected_server_name = '' + self._service_name = '' #: The desired, or actual, address of the connected server. self.address = (host, int(port)) @@ -215,6 +224,15 @@ class XMLStream(object): #: If set to ``True``, attempt to use IPv6. self.use_ipv6 = True + #: If set to ``True``, allow using the ``dnspython`` DNS library + #: if available. If set to ``False``, the builtin DNS resolver + #: will be used, even if ``dnspython`` is installed. + self.use_dnspython = True + + #: Use CDATA for escaping instead of XML entities. Defaults + #: to ``False``. + self.use_cdata = False + #: An optional dictionary of proxy settings. It may provide: #: :host: The host offering proxy services. #: :port: The port for the proxy service. @@ -270,10 +288,10 @@ class XMLStream(object): self.end_session_on_disconnect = True #: A queue of stream, custom, and scheduled events to be processed. - self.event_queue = queue.Queue() + self.event_queue = Queue() #: A queue of string data to be sent over the stream. - self.send_queue = queue.Queue() + self.send_queue = Queue(maxsize=256) self.send_queue_lock = threading.Lock() self.send_lock = threading.RLock() @@ -322,7 +340,7 @@ class XMLStream(object): #: ``_xmpp-client._tcp`` service. self.dns_service = None - self.add_event_handler('connected', self._handle_connected) + self.add_event_handler('connected', self._session_timeout_check) self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('session_start', self._cert_expiration) @@ -407,6 +425,8 @@ class XMLStream(object): :param reattempt: Flag indicating if the socket should reconnect after disconnections. """ + self.stop.clear() + if host and port: self.address = (host, int(port)) try: @@ -439,11 +459,12 @@ class XMLStream(object): def _connect(self, reattempt=True): self.scheduler.remove('Session timeout check') - self.stop.clear() - if self.reconnect_delay is None or not reattempt: + if self.reconnect_delay is None: delay = 1.0 - else: + self.reconnect_delay = delay + + if reattempt: delay = min(self.reconnect_delay * 2, self.reconnect_max_delay) delay = random.normalvariate(delay, delay * 0.1) log.debug('Waiting %s seconds before connecting.', delay) @@ -453,16 +474,18 @@ class XMLStream(object): time.sleep(0.1) elapsed += 0.1 except KeyboardInterrupt: - self.stop.set() + self.set_stop() return False except SystemExit: - self.stop.set() + self.set_stop() return False if self.default_domain: try: - self.address = self.pick_dns_answer(self.default_domain, - self.address[1]) + host, address, port = self.pick_dns_answer(self.default_domain, + self.address[1]) + self.address = (address, port) + self._service_name = host except StopIteration: log.debug("No remaining DNS records to try.") self.dns_answers = None @@ -490,17 +513,26 @@ class XMLStream(object): self.reconnect_delay = delay return False - if self.use_ssl and self.ssl_support: + if self.use_ssl: log.debug("Socket Wrapped for SSL") if self.ca_certs is None: cert_policy = ssl.CERT_NONE else: cert_policy = ssl.CERT_REQUIRED - ssl_socket = ssl.wrap_socket(self.socket, - ca_certs=self.ca_certs, - cert_reqs=cert_policy, - do_handshake_on_connect=False) + ssl_args = safedict({ + 'certfile': self.certfile, + 'keyfile': self.keyfile, + 'ca_certs': self.ca_certs, + 'cert_reqs': cert_policy, + 'do_handshake_on_connect': False, + "ssl_version": self.ssl_version + }) + + if sys.version_info >= (2, 7): + ssl_args['ciphers'] = self.ciphers + + ssl_socket = ssl.wrap_socket(self.socket, **ssl_args) if hasattr(self.socket, 'socket'): # We are using a testing socket, so preserve the top @@ -517,7 +549,7 @@ class XMLStream(object): log.debug("Connecting to %s:%s", domain, self.address[1]) self.socket.connect(self.address) - if self.use_ssl and self.ssl_support: + if self.use_ssl: try: self.socket.do_handshake() except (Socket.error, ssl.SSLError): @@ -538,7 +570,7 @@ class XMLStream(object): cert.verify(self._expected_server_name, self._der_cert) except cert.CertificateError as err: if not self.event_handled('ssl_invalid_cert'): - log.error(err.message) + log.error(err) self.disconnect(send_close=False) else: self.event('ssl_invalid_cert', @@ -547,8 +579,7 @@ class XMLStream(object): self.set_socket(self.socket, ignore=True) #this event is where you should set your application state - self.event("connected", direct=True) - self.reconnect_delay = 1.0 + self.event('connected', direct=True) return True except (Socket.error, ssl.SSLError) as serr: error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" @@ -588,7 +619,7 @@ class XMLStream(object): headers = '\r\n'.join(headers) + '\r\n\r\n' try: - log.debug("Connecting to proxy: %s:%s", address) + log.debug("Connecting to proxy: %s:%s", *address) self.socket.connect(address) self.send_raw(headers, now=True) resp = '' @@ -599,6 +630,7 @@ class XMLStream(object): lines = resp.split('\r\n') if '200' not in lines[0]: self.event('proxy_error', resp) + self.event('connection_failed', direct=True) log.error('Proxy Error: %s', lines[0]) return False @@ -612,7 +644,7 @@ class XMLStream(object): serr.errno, serr.strerror) return False - def _handle_connected(self, event=None): + def _session_timeout_check(self, event=None): """ Add check to ensure that a session is established within a reasonable amount of time. @@ -661,6 +693,9 @@ class XMLStream(object): args=(reconnect, wait, send_close)) def _disconnect(self, reconnect=False, wait=None, send_close=True): + if not reconnect: + self.auto_reconnect = False + if self.end_session_on_disconnect or send_close: self.event('session_end', direct=True) @@ -684,7 +719,6 @@ class XMLStream(object): # closed in the other direction. If we didn't # send a stream footer we don't need to wait # since the server won't know to respond. - self.auto_reconnect = reconnect if send_close: log.info('Waiting for %s from server', self.stream_footer) self.stream_end_event.wait(4) @@ -692,7 +726,7 @@ class XMLStream(object): self.stream_end_event.set() if not self.auto_reconnect: - self.stop.set() + self.set_stop() if self._disconnect_wait_for_threads: self._wait_for_threads() @@ -704,9 +738,23 @@ class XMLStream(object): self.event('socket_error', serr, direct=True) finally: #clear your application state - self.event("disconnected", direct=True) + self.event('disconnected', direct=True) return True + def abort(self): + self.session_started_event.clear() + self.set_stop() + if self._disconnect_wait_for_threads: + self._wait_for_threads() + try: + self.socket.shutdown(Socket.SHUT_RDWR) + self.socket.close() + self.filesocket.close() + except Socket.error: + pass + self.state.transition_any(['connected', 'disconnected'], 'disconnected', func=lambda: True) + self.event("killed", direct=True) + def reconnect(self, reattempt=True, wait=False, send_close=True): """Reset the stream's state and reconnect to the server.""" log.debug("reconnecting...") @@ -789,56 +837,62 @@ class XMLStream(object): If the handshake is successful, the XML stream will need to be restarted. """ - if self.ssl_support: - log.info("Negotiating TLS") - log.info("Using SSL version: %s", str(self.ssl_version)) - if self.ca_certs is None: - cert_policy = ssl.CERT_NONE - else: - cert_policy = ssl.CERT_REQUIRED - - ssl_socket = ssl.wrap_socket(self.socket, - ssl_version=self.ssl_version, - do_handshake_on_connect=False, - ca_certs=self.ca_certs, - cert_reqs=cert_policy) + log.info("Negotiating TLS") + ssl_versions = {3: 'TLS 1.0', 1: 'SSL 3', 2: 'SSL 2/3'} + log.info("Using SSL version: %s", ssl_versions[self.ssl_version]) + if self.ca_certs is None: + cert_policy = ssl.CERT_NONE + else: + cert_policy = ssl.CERT_REQUIRED + + ssl_args = safedict({ + 'certfile': self.certfile, + 'keyfile': self.keyfile, + 'ca_certs': self.ca_certs, + 'cert_reqs': cert_policy, + 'do_handshake_on_connect': False, + "ssl_version": self.ssl_version + }) + + if sys.version_info >= (2, 7): + ssl_args['ciphers'] = self.ciphers + + ssl_socket = ssl.wrap_socket(self.socket, **ssl_args) + + if hasattr(self.socket, 'socket'): + # We are using a testing socket, so preserve the top + # layer of wrapping. + self.socket.socket = ssl_socket + else: + self.socket = ssl_socket - if hasattr(self.socket, 'socket'): - # We are using a testing socket, so preserve the top - # layer of wrapping. - self.socket.socket = ssl_socket + try: + self.socket.do_handshake() + except (Socket.error, ssl.SSLError): + log.error('CERT: Invalid certificate trust chain.') + if not self.event_handled('ssl_invalid_chain'): + self.disconnect(self.auto_reconnect, send_close=False) else: - self.socket = ssl_socket - - try: - self.socket.do_handshake() - except (Socket.error, ssl.SSLError): - log.error('CERT: Invalid certificate trust chain.') - if not self.event_handled('ssl_invalid_chain'): - self.disconnect(self.auto_reconnect, send_close=False) - else: - self.event('ssl_invalid_chain', direct=True) - return False + self._der_cert = self.socket.getpeercert(binary_form=True) + self.event('ssl_invalid_chain', direct=True) + return False - self._der_cert = self.socket.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) - log.debug('CERT: %s', pem_cert) - self.event('ssl_cert', pem_cert, direct=True) + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + self.event('ssl_cert', pem_cert, direct=True) - try: - cert.verify(self._expected_server_name, self._der_cert) - except cert.CertificateError as err: - if not self.event_handled('ssl_invalid_cert'): - log.error(err.message) - self.disconnect(self.auto_reconnect, send_close=False) - else: - self.event('ssl_invalid_cert', pem_cert, direct=True) + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + if not self.event_handled('ssl_invalid_cert'): + log.error(err) + self.disconnect(self.auto_reconnect, send_close=False) + else: + self.event('ssl_invalid_cert', pem_cert, direct=True) - self.set_socket(self.socket) - return True - else: - log.warning("Tried to enable TLS, but ssl module not found.") - return False + self.set_socket(self.socket) + return True def _cert_expiration(self, event): """Schedule an event for when the TLS certificate expires.""" @@ -866,9 +920,15 @@ class XMLStream(object): log.warn('CERT: Certificate has expired.') restart() + try: + total_seconds = cert_ttl.total_seconds() + except AttributeError: + # for Python < 2.7 + total_seconds = (cert_ttl.microseconds + (cert_ttl.seconds + cert_ttl.days * 24 * 3600) * 10**6) / 10**6 + log.info('CERT: Time until certificate expiration: %s' % cert_ttl) self.schedule('Certificate Expiration', - cert_ttl.seconds, + total_seconds, restart) def _start_keepalive(self, event): @@ -882,12 +942,13 @@ class XMLStream(object): self.whitespace_keepalive_interval = 300 """ - self.schedule('Whitespace Keepalive', - self.whitespace_keepalive_interval, - self.send_raw, - args=(' ',), - kwargs={'now': True}, - repeat=True) + if self.whitespace_keepalive: + self.schedule('Whitespace Keepalive', + self.whitespace_keepalive_interval, + self.send_raw, + args=(' ',), + kwargs={'now': True}, + repeat=True) def _remove_schedules(self, event): """Remove whitespace keepalive and certificate expiration schedules.""" @@ -983,9 +1044,13 @@ class XMLStream(object): # and handler classes here. if name is None: - name = 'add_handler_%s' % self.getNewId() - self.registerHandler(XMLCallback(name, MatchXMLMask(mask), pointer, - once=disposable, instream=instream)) + name = 'add_handler_%s' % self.new_id() + self.register_handler( + XMLCallback(name, + MatchXMLMask(mask, self.default_ns), + pointer, + once=disposable, + instream=instream)) def register_handler(self, handler, before=None, after=None): """Add a stream event handler that will be executed when a matching @@ -1026,7 +1091,8 @@ class XMLStream(object): return resolve(domain, port, service=self.dns_service, resolver=resolver, - use_ipv6=self.use_ipv6) + use_ipv6=self.use_ipv6, + use_dnspython=self.use_dnspython) def pick_dns_answer(self, domain, port=None): """Pick a server and port from DNS answers. @@ -1087,7 +1153,7 @@ class XMLStream(object): """ return len(self.__event_handlers.get(name, [])) - def event(self, name, data={}, direct=False): + def event(self, name, data=None, direct=False): """Manually trigger a custom event. :param name: The name of the event to trigger. @@ -1098,6 +1164,11 @@ class XMLStream(object): event queue. All event handlers will run in the same thread. """ + if not data: + data = {} + + log.debug("Event triggered: " + name) + handlers = self.__event_handlers.get(name, []) for handler in handlers: #TODO: Data should not be copied, but should be read only, @@ -1202,7 +1273,9 @@ class XMLStream(object): data = filter(data) if data is None: return - str_data = str(data) + str_data = tostring(data.xml, xmlns=self.default_ns, + stream=self, + top_level=True) self.send_raw(str_data, now) else: self.send_raw(data, now) @@ -1267,6 +1340,9 @@ class XMLStream(object): if not self.stop.is_set(): time.sleep(self.ssl_retry_delay) tries += 1 + except Socket.error as serr: + if serr.errno != errno.EINTR: + raise if count > 1: log.debug('SENT: %d chunks', count) except (Socket.error, ssl.SSLError) as serr: @@ -1281,12 +1357,12 @@ class XMLStream(object): return True def _start_thread(self, name, target, track=True): - self.__active_threads.add(name) self.__thread[name] = threading.Thread(name=name, target=target) self.__thread[name].daemon = self._use_daemons self.__thread[name].start() if track: + self.__active_threads.add(name) with self.__thread_cond: self.__thread_count += 1 @@ -1315,6 +1391,13 @@ class XMLStream(object): if self.__thread_count == 0: self.__thread_cond.notify() + def set_stop(self): + self.stop.set() + + # Unlock queues + self.event_queue.put(None) + self.send_queue.put(None) + def _wait_for_threads(self): with self.__thread_cond: if self.__thread_count != 0: @@ -1458,6 +1541,10 @@ class XMLStream(object): # as handshakes. self.stream_end_event.clear() self.start_stream_handler(root) + + # We have a successful stream connection, so reset + # exponential backoff for new reconnect attempts. + self.reconnect_delay = 1.0 depth += 1 if event == b'end': depth -= 1 @@ -1583,11 +1670,7 @@ class XMLStream(object): log.debug("Loading event runner") try: while not self.stop.is_set(): - try: - wait = self.wait_timeout - event = self.event_queue.get(True, timeout=wait) - except queue.Empty: - event = None + event = self.event_queue.get() if event is None: continue @@ -1603,10 +1686,10 @@ class XMLStream(object): log.exception(error_msg, handler.name) orig.exception(e) elif etype == 'schedule': - name = args[1] + name = args[2] try: log.debug('Scheduled event: %s: %s', name, args[0]) - handler(*args[0]) + handler(*args[0], **args[1]) except Exception as e: log.exception('Error processing scheduled task') self.exception(e) @@ -1648,14 +1731,13 @@ class XMLStream(object): while not self.stop.is_set(): while not self.stop.is_set() and \ not self.session_started_event.is_set(): - self.session_started_event.wait(timeout=0.1) + self.session_started_event.wait(timeout=0.1) # Wait for session start if self.__failed_send_stanza is not None: data = self.__failed_send_stanza self.__failed_send_stanza = None else: - try: - data = self.send_queue.get(True, 1) - except queue.Empty: + data = self.send_queue.get() # Wait for data to send + if data is None: continue log.debug("SEND: %s", data) enc_data = data.encode('utf-8') @@ -1682,6 +1764,9 @@ class XMLStream(object): if not self.stop.is_set(): time.sleep(self.ssl_retry_delay) tries += 1 + except Socket.error as serr: + if serr.errno != errno.EINTR: + raise if count > 1: log.debug('SENT: %d chunks', count) self.send_queue.task_done() |