summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r--sleekxmpp/xmlstream/cert.py4
-rw-r--r--sleekxmpp/xmlstream/filesocket.py9
-rw-r--r--sleekxmpp/xmlstream/matcher/__init__.py1
-rw-r--r--sleekxmpp/xmlstream/matcher/idsender.py47
-rw-r--r--sleekxmpp/xmlstream/matcher/xmlmask.py71
-rw-r--r--sleekxmpp/xmlstream/matcher/xpath.py37
-rw-r--r--sleekxmpp/xmlstream/resolver.py58
-rw-r--r--sleekxmpp/xmlstream/scheduler.py34
-rw-r--r--sleekxmpp/xmlstream/stanzabase.py66
-rw-r--r--sleekxmpp/xmlstream/tostring.py35
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py154
11 files changed, 313 insertions, 203 deletions
diff --git a/sleekxmpp/xmlstream/cert.py b/sleekxmpp/xmlstream/cert.py
index fa12f794..71146f36 100644
--- a/sleekxmpp/xmlstream/cert.py
+++ b/sleekxmpp/xmlstream/cert.py
@@ -1,6 +1,10 @@
import logging
from datetime import datetime, timedelta
+# Make a call to strptime before starting threads to
+# prevent thread safety issues.
+datetime.strptime('1970-01-01 12:00:00', "%Y-%m-%d %H:%M:%S")
+
try:
from pyasn1.codec.der import decoder, encoder
diff --git a/sleekxmpp/xmlstream/filesocket.py b/sleekxmpp/xmlstream/filesocket.py
index d4537998..53b83bc7 100644
--- a/sleekxmpp/xmlstream/filesocket.py
+++ b/sleekxmpp/xmlstream/filesocket.py
@@ -13,6 +13,7 @@
"""
from socket import _fileobject
+import errno
import socket
@@ -29,7 +30,13 @@ class FileSocket(_fileobject):
"""Read data from the socket as if it were a file."""
if self._sock is None:
return None
- data = self._sock.recv(size)
+ while True:
+ try:
+ data = self._sock.recv(size)
+ break
+ except socket.error as serr:
+ if serr.errno != errno.EINTR:
+ raise
if data is not None:
return data
diff --git a/sleekxmpp/xmlstream/matcher/__init__.py b/sleekxmpp/xmlstream/matcher/__init__.py
index 1038d1bd..aa74c434 100644
--- a/sleekxmpp/xmlstream/matcher/__init__.py
+++ b/sleekxmpp/xmlstream/matcher/__init__.py
@@ -7,6 +7,7 @@
"""
from sleekxmpp.xmlstream.matcher.id import MatcherId
+from sleekxmpp.xmlstream.matcher.idsender import MatchIDSender
from sleekxmpp.xmlstream.matcher.many import MatchMany
from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath
from sleekxmpp.xmlstream.matcher.xmlmask import MatchXMLMask
diff --git a/sleekxmpp/xmlstream/matcher/idsender.py b/sleekxmpp/xmlstream/matcher/idsender.py
new file mode 100644
index 00000000..5c2c1f51
--- /dev/null
+++ b/sleekxmpp/xmlstream/matcher/idsender.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+"""
+ sleekxmpp.xmlstream.matcher.id
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ Part of SleekXMPP: The Sleek XMPP Library
+
+ :copyright: (c) 2011 Nathanael C. Fritz
+ :license: MIT, see LICENSE for more details
+"""
+
+from sleekxmpp.xmlstream.matcher.base import MatcherBase
+
+
+class MatchIDSender(MatcherBase):
+
+ """
+ The IDSender matcher selects stanzas that have the same stanza 'id'
+ interface value as the desired ID, and that the 'from' value is one
+ of a set of approved entities that can respond to a request.
+ """
+
+ def match(self, xml):
+ """Compare the given stanza's ``'id'`` attribute to the stored
+ ``id`` value, and verify the sender's JID.
+
+ :param xml: The :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase`
+ stanza to compare against.
+ """
+
+ selfjid = self._criteria['self']
+ peerjid = self._criteria['peer']
+
+ allowed = {}
+ allowed[''] = True
+ allowed[selfjid.bare] = True
+ allowed[selfjid.host] = True
+ allowed[peerjid.full] = True
+ allowed[peerjid.bare] = True
+ allowed[peerjid.host] = True
+
+ _from = xml['from']
+
+ try:
+ return xml['id'] == self._criteria['id'] and allowed[_from]
+ except KeyError:
+ return False
diff --git a/sleekxmpp/xmlstream/matcher/xmlmask.py b/sleekxmpp/xmlstream/matcher/xmlmask.py
index a0568f08..56f728e1 100644
--- a/sleekxmpp/xmlstream/matcher/xmlmask.py
+++ b/sleekxmpp/xmlstream/matcher/xmlmask.py
@@ -14,12 +14,6 @@ from sleekxmpp.xmlstream.stanzabase import ET
from sleekxmpp.xmlstream.matcher.base import MatcherBase
-# Flag indicating if the builtin XPath matcher should be used, which
-# uses namespaces, or a custom matcher that ignores namespaces.
-# Changing this will affect ALL XMLMask matchers.
-IGNORE_NS = False
-
-
log = logging.getLogger(__name__)
@@ -39,19 +33,15 @@ class MatchXMLMask(MatcherBase):
:class:`~sleekxmpp.xmlstream.matcher.stanzapath.StanzaPath`
should be used instead.
- The use of namespaces in the mask comparison is controlled by
- ``IGNORE_NS``. Setting ``IGNORE_NS`` to ``True`` will disable namespace
- based matching for ALL XMLMask matchers.
-
:param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
object or XML string to use as a mask.
"""
- def __init__(self, criteria):
+ def __init__(self, criteria, default_ns='jabber:client'):
MatcherBase.__init__(self, criteria)
if isinstance(criteria, str):
self._criteria = ET.fromstring(self._criteria)
- self.default_ns = 'jabber:client'
+ self.default_ns = default_ns
def setDefaultNS(self, ns):
"""Set the default namespace to use during comparisons.
@@ -84,8 +74,6 @@ class MatchXMLMask(MatcherBase):
do not have a specified namespace.
Defaults to ``"__no_ns__"``.
"""
- use_ns = not IGNORE_NS
-
if source is None:
# If the element was not found. May happend during recursive calls.
return False
@@ -96,17 +84,10 @@ class MatchXMLMask(MatcherBase):
mask = ET.fromstring(mask)
except ExpatError:
log.warning("Expat error: %s\nIn parsing: %s", '', mask)
- if not use_ns:
- # Compare the element without using namespaces.
- source_tag = source.tag.split('}', 1)[-1]
- mask_tag = mask.tag.split('}', 1)[-1]
- if source_tag != mask_tag:
- return False
- else:
- # Compare the element using namespaces
- mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
- if source.tag not in [mask.tag, mask_ns_tag]:
- return False
+
+ mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
+ if source.tag not in [mask.tag, mask_ns_tag]:
+ return False
# If the mask includes text, compare it.
if mask.text and source.text and \
@@ -122,37 +103,15 @@ class MatchXMLMask(MatcherBase):
# Recursively check subelements.
matched_elements = {}
for subelement in mask:
- if use_ns:
- matched = False
- for other in source.findall(subelement.tag):
- matched_elements[other] = False
- if self._mask_cmp(other, subelement, use_ns):
- if not matched_elements.get(other, False):
- matched_elements[other] = True
- matched = True
- if not matched:
- return False
- else:
- if not self._mask_cmp(self._get_child(source, subelement.tag),
- subelement, use_ns):
- return False
+ matched = False
+ for other in source.findall(subelement.tag):
+ matched_elements[other] = False
+ if self._mask_cmp(other, subelement, use_ns):
+ if not matched_elements.get(other, False):
+ matched_elements[other] = True
+ matched = True
+ if not matched:
+ return False
# Everything matches.
return True
-
- def _get_child(self, xml, tag):
- """Return a child element given its tag, ignoring namespace values.
-
- Returns ``None`` if the child was not found.
-
- :param xml: The :class:`~xml.etree.ElementTree.Element` XML object
- to search for the given child tag.
- :param tag: The name of the subelement to find.
- """
- tag = tag.split('}')[-1]
- try:
- children = [c.tag.split('}')[-1] for c in xml]
- index = children.index(tag)
- except ValueError:
- return None
- return list(xml)[index]
diff --git a/sleekxmpp/xmlstream/matcher/xpath.py b/sleekxmpp/xmlstream/matcher/xpath.py
index 3f03e68e..f3d28429 100644
--- a/sleekxmpp/xmlstream/matcher/xpath.py
+++ b/sleekxmpp/xmlstream/matcher/xpath.py
@@ -9,16 +9,10 @@
:license: MIT, see LICENSE for more details
"""
-from sleekxmpp.xmlstream.stanzabase import ET
+from sleekxmpp.xmlstream.stanzabase import ET, fix_ns
from sleekxmpp.xmlstream.matcher.base import MatcherBase
-# Flag indicating if the builtin XPath matcher should be used, which
-# uses namespaces, or a custom matcher that ignores namespaces.
-# Changing this will affect ALL XPath matchers.
-IGNORE_NS = False
-
-
class MatchXPath(MatcherBase):
"""
@@ -38,6 +32,9 @@ class MatchXPath(MatcherBase):
expressions will be matched without using namespaces.
"""
+ def __init__(self, criteria):
+ self._criteria = fix_ns(criteria)
+
def match(self, xml):
"""
Compare a stanza's XML contents to an XPath expression.
@@ -59,28 +56,4 @@ class MatchXPath(MatcherBase):
x = ET.Element('x')
x.append(xml)
- if not IGNORE_NS:
- # Use builtin, namespace respecting, XPath matcher.
- if x.find(self._criteria) is not None:
- return True
- return False
- else:
- # Remove namespaces from the XPath expression.
- criteria = []
- for ns_block in self._criteria.split('{'):
- criteria.extend(ns_block.split('}')[-1].split('/'))
-
- # Walk the XPath expression.
- xml = x
- for tag in criteria:
- if not tag:
- # Skip empty tag name artifacts from the cleanup phase.
- continue
-
- children = [c.tag.split('}')[-1] for c in xml]
- try:
- index = children.index(tag)
- except ValueError:
- return False
- xml = list(xml)[index]
- return True
+ return x.find(self._criteria) is not None
diff --git a/sleekxmpp/xmlstream/resolver.py b/sleekxmpp/xmlstream/resolver.py
index 394daa64..188e5ac7 100644
--- a/sleekxmpp/xmlstream/resolver.py
+++ b/sleekxmpp/xmlstream/resolver.py
@@ -32,10 +32,10 @@ log = logging.getLogger(__name__)
#: cd dnspython
#: git checkout python3
#: python3 setup.py install
-USE_DNSPYTHON = False
+DNSPYTHON_AVAILABLE = False
try:
import dns.resolver
- USE_DNSPYTHON = True
+ DNSPYTHON_AVAILABLE = True
except ImportError as e:
log.debug("Could not find dnspython package. " + \
"Not all features will be available")
@@ -47,13 +47,13 @@ def default_resolver():
:returns: A :class:`dns.resolver.Resolver` object if dnspython
is available. Otherwise, ``None``.
"""
- if USE_DNSPYTHON:
+ if DNSPYTHON_AVAILABLE:
return dns.resolver.get_default_resolver()
return None
def resolve(host, port=None, service=None, proto='tcp',
- resolver=None, use_ipv6=True):
+ resolver=None, use_ipv6=True, use_dnspython=True):
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@@ -77,6 +77,9 @@ def resolve(host, port=None, service=None, proto='tcp',
:param use_ipv6: Optionally control the use of IPv6 in situations
where it is either not available, or performance
is degraded. Defaults to ``True``.
+ :param use_dnspython: Optionally control if dnspython is used to make
+ the DNS queries instead of the built-in DNS
+ library.
:type host: string
:type port: int
@@ -84,14 +87,22 @@ def resolve(host, port=None, service=None, proto='tcp',
:type proto: string
:type resolver: :class:`dns.resolver.Resolver`
:type use_ipv6: bool
+ :type use_dnspython: bool
:return: An iterable of IP address, port pairs in the order
dictated by SRV priorities and weights, if applicable.
"""
+
+ if not use_dnspython:
+ if DNSPYTHON_AVAILABLE:
+ log.debug("DNS: Not using dnspython, but dnspython is installed.")
+ else:
+ log.debug("DNS: Not using dnspython.")
+
if not use_ipv6:
log.debug("DNS: Use of IPv6 has been disabled.")
- if resolver is None and USE_DNSPYTHON:
+ if resolver is None and DNSPYTHON_AVAILABLE and use_dnspython:
resolver = dns.resolver.get_default_resolver()
# An IPv6 literal is allowed to be enclosed in square brackets, but
@@ -113,7 +124,7 @@ def resolve(host, port=None, service=None, proto='tcp',
if hasattr(socket, 'inet_pton'):
ipv6 = socket.inet_pton(socket.AF_INET6, host)
yield (host, host, port)
- except socket.error:
+ except (socket.error, ValueError):
pass
# If no service was provided, then we can just do A/AAAA lookups on the
@@ -122,7 +133,9 @@ def resolve(host, port=None, service=None, proto='tcp',
if not service:
hosts = [(host, port)]
else:
- hosts = get_SRV(host, port, service, proto, resolver=resolver)
+ hosts = get_SRV(host, port, service, proto,
+ resolver=resolver,
+ use_dnspython=use_dnspython)
for host, port in hosts:
results = []
@@ -131,16 +144,18 @@ def resolve(host, port=None, service=None, proto='tcp',
results.append((host, '::1', port))
results.append((host, '127.0.0.1', port))
if use_ipv6:
- for address in get_AAAA(host, resolver=resolver):
+ for address in get_AAAA(host, resolver=resolver,
+ use_dnspython=use_dnspython):
results.append((host, address, port))
- for address in get_A(host, resolver=resolver):
+ for address in get_A(host, resolver=resolver,
+ use_dnspython=use_dnspython):
results.append((host, address, port))
for host, address, port in results:
yield host, address, port
-def get_A(host, resolver=None):
+def get_A(host, resolver=None, use_dnspython=True):
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -148,9 +163,13 @@ def get_A(host, resolver=None):
:param host: The hostname to resolve for A record IPv4 addresses.
:param resolver: Optional DNS resolver object to use for the query.
+ :param use_dnspython: Optionally control if dnspython is used to make
+ the DNS queries instead of the built-in DNS
+ library.
:type host: string
:type resolver: :class:`dns.resolver.Resolver` or ``None``
+ :type use_dnspython: bool
:return: A list of IPv4 literals.
"""
@@ -158,7 +177,7 @@ def get_A(host, resolver=None):
# If not using dnspython, attempt lookup using the OS level
# getaddrinfo() method.
- if resolver is None:
+ if resolver is None or not use_dnspython:
try:
recs = socket.getaddrinfo(host, None, socket.AF_INET,
socket.SOCK_STREAM)
@@ -183,7 +202,7 @@ def get_A(host, resolver=None):
return []
-def get_AAAA(host, resolver=None):
+def get_AAAA(host, resolver=None, use_dnspython=True):
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -191,9 +210,13 @@ def get_AAAA(host, resolver=None):
:param host: The hostname to resolve for AAAA record IPv6 addresses.
:param resolver: Optional DNS resolver object to use for the query.
+ :param use_dnspython: Optionally control if dnspython is used to make
+ the DNS queries instead of the built-in DNS
+ library.
:type host: string
:type resolver: :class:`dns.resolver.Resolver` or ``None``
+ :type use_dnspython: bool
:return: A list of IPv6 literals.
"""
@@ -201,12 +224,15 @@ def get_AAAA(host, resolver=None):
# If not using dnspython, attempt lookup using the OS level
# getaddrinfo() method.
- if resolver is None:
+ if resolver is None or not use_dnspython:
+ if not socket.has_ipv6:
+ log.debug("Unable to query %s for AAAA records: IPv6 is not supported", host)
+ return []
try:
recs = socket.getaddrinfo(host, None, socket.AF_INET6,
socket.SOCK_STREAM)
return [rec[4][0] for rec in recs]
- except socket.gaierror:
+ except (OSError, socket.gaierror):
log.debug("DNS: Error retreiving AAAA address " + \
"info for %s." % host)
return []
@@ -227,7 +253,7 @@ def get_AAAA(host, resolver=None):
return []
-def get_SRV(host, port, service, proto='tcp', resolver=None):
+def get_SRV(host, port, service, proto='tcp', resolver=None, use_dnspython=True):
"""Perform SRV record resolution for a given host.
.. note::
@@ -253,7 +279,7 @@ def get_SRV(host, port, service, proto='tcp', resolver=None):
:return: A list of hostname, port pairs in the order dictacted
by SRV priorities and weights.
"""
- if resolver is None:
+ if resolver is None or not use_dnspython:
log.warning("DNS: dnspython not found. Can not use SRV lookup.")
return [(host, port)]
diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py
index b3e50983..e6fae37a 100644
--- a/sleekxmpp/xmlstream/scheduler.py
+++ b/sleekxmpp/xmlstream/scheduler.py
@@ -20,6 +20,11 @@ import itertools
from sleekxmpp.util import Queue, QueueEmpty
+#: The time in seconds to wait for events from the event queue, and also the
+#: time between checks for the process stop signal.
+WAIT_TIMEOUT = 1.0
+
+
log = logging.getLogger(__name__)
@@ -76,7 +81,7 @@ class Task(object):
"""
if self.qpointer is not None:
self.qpointer.put(('schedule', self.callback,
- self.args, self.name))
+ self.args, self.kwargs, self.name))
else:
self.callback(*self.args, **self.kwargs)
self.reset()
@@ -120,6 +125,10 @@ class Scheduler(object):
#: Lock for accessing the task queue.
self.schedule_lock = threading.RLock()
+ #: The time in seconds to wait for events from the event queue,
+ #: and also the time between checks for the process stop signal.
+ self.wait_timeout = WAIT_TIMEOUT
+
def process(self, threaded=True, daemon=False):
"""Begin accepting and processing scheduled tasks.
@@ -139,24 +148,25 @@ class Scheduler(object):
self.run = True
try:
while self.run and not self.stop.is_set():
- wait = 0.1
updated = False
if self.schedule:
wait = self.schedule[0].next - time.time()
+ else:
+ wait = self.wait_timeout
try:
if wait <= 0.0:
newtask = self.addq.get(False)
else:
- if wait >= 3.0:
- wait = 3.0
newtask = None
- elapsed = 0
- while not self.stop.is_set() and \
+ while self.run and \
+ not self.stop.is_set() and \
newtask is None and \
- elapsed < wait:
- newtask = self.addq.get(True, 0.1)
- elapsed += 0.1
- except QueueEmpty:
+ wait > 0:
+ try:
+ newtask = self.addq.get(True, min(wait, self.wait_timeout))
+ except QueueEmpty: # Nothing to add, nothing to do. Check run flags and continue waiting.
+ wait -= self.wait_timeout
+ except QueueEmpty: # Time to run some tasks, and no new tasks to add.
self.schedule_lock.acquire()
# select only those tasks which are to be executed now
relevant = itertools.takewhile(
@@ -174,11 +184,11 @@ class Scheduler(object):
# only need to resort tasks if a repeated task has
# been kept in the list.
updated = True
- else:
- updated = True
+ else: # Add new task
self.schedule_lock.acquire()
if newtask is not None:
self.schedule.append(newtask)
+ updated = True
finally:
if updated:
self.schedule.sort(key=lambda task: task.next)
diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py
index 122d7eb4..11c8dd67 100644
--- a/sleekxmpp/xmlstream/stanzabase.py
+++ b/sleekxmpp/xmlstream/stanzabase.py
@@ -3,7 +3,7 @@
sleekxmpp.xmlstream.stanzabase
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- This module implements a wrapper layer for XML objects
+ module implements a wrapper layer for XML objects
that allows them to be treated like dictionaries.
Part of SleekXMPP: The Sleek XMPP Library
@@ -19,6 +19,7 @@ import logging
import weakref
from xml.etree import cElementTree as ET
+from sleekxmpp.util import safedict
from sleekxmpp.xmlstream import JID
from sleekxmpp.xmlstream.tostring import tostring
from sleekxmpp.thirdparty import OrderedDict
@@ -141,7 +142,7 @@ def multifactory(stanza, plugin_attrib):
parent.loaded_plugins.remove(plugin_attrib)
try:
parent.xml.remove(self.xml)
- except:
+ except ValueError:
pass
else:
for stanza in list(res):
@@ -192,7 +193,7 @@ def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''):
for element in elements:
if element:
# Skip empty entry artifacts from splitting.
- if propagate_ns:
+ if propagate_ns and element[0] != '*':
tag = '{%s}%s' % (namespace, element)
else:
tag = element
@@ -565,7 +566,10 @@ class ElementBase(object):
values = {}
values['lang'] = self['lang']
for interface in self.interfaces:
- values[interface] = self[interface]
+ if isinstance(self[interface], JID):
+ values[interface] = self[interface].jid
+ else:
+ values[interface] = self[interface]
if interface in self.lang_interfaces:
values['%s|*' % interface] = self['%s|*' % interface]
for plugin, stanza in self.plugins.items():
@@ -596,31 +600,39 @@ class ElementBase(object):
iterable_interfaces = [p.plugin_attrib for \
p in self.plugin_iterables]
+ if 'lang' in values:
+ self['lang'] = values['lang']
+
+ if 'substanzas' in values:
+ # Remove existing substanzas
+ for stanza in self.iterables:
+ try:
+ self.xml.remove(stanza.xml)
+ except ValueError:
+ pass
+ self.iterables = []
+
+ # Add new substanzas
+ for subdict in values['substanzas']:
+ if '__childtag__' in subdict:
+ for subclass in self.plugin_iterables:
+ child_tag = "{%s}%s" % (subclass.namespace,
+ subclass.name)
+ if subdict['__childtag__'] == child_tag:
+ sub = subclass(parent=self)
+ sub.values = subdict
+ self.iterables.append(sub)
+
for interface, value in values.items():
full_interface = interface
interface_lang = ('%s|' % interface).split('|')
interface = interface_lang[0]
lang = interface_lang[1] or self.get_lang()
- if interface == 'substanzas':
- # Remove existing substanzas
- for stanza in self.iterables:
- self.xml.remove(stanza.xml)
- self.iterables = []
-
- # Add new substanzas
- for subdict in value:
- if '__childtag__' in subdict:
- for subclass in self.plugin_iterables:
- child_tag = "{%s}%s" % (subclass.namespace,
- subclass.name)
- if subdict['__childtag__'] == child_tag:
- sub = subclass(parent=self)
- sub.values = subdict
- self.iterables.append(sub)
- break
- elif interface == 'lang':
- self[interface] = value
+ if interface == 'lang':
+ continue
+ elif interface == 'substanzas':
+ continue
elif interface in self.interfaces:
self[full_interface] = value
elif interface in self.plugin_attrib_map:
@@ -668,6 +680,8 @@ class ElementBase(object):
if lang and attrib in self.lang_interfaces:
kwargs['lang'] = lang
+ kwargs = safedict(kwargs)
+
if attrib == 'substanzas':
return self.iterables
elif attrib in self.interfaces or attrib == 'lang':
@@ -744,6 +758,8 @@ class ElementBase(object):
if lang and attrib in self.lang_interfaces:
kwargs['lang'] = lang
+ kwargs = safedict(kwargs)
+
if attrib in self.interfaces or attrib == 'lang':
if value is not None:
set_method = "set_%s" % attrib.lower()
@@ -830,6 +846,8 @@ class ElementBase(object):
if lang and attrib in self.lang_interfaces:
kwargs['lang'] = lang
+ kwargs = safedict(kwargs)
+
if attrib in self.interfaces or attrib == 'lang':
del_method = "del_%s" % attrib.lower()
del_method2 = "del%s" % attrib.title()
@@ -866,7 +884,7 @@ class ElementBase(object):
self.loaded_plugins.remove(attrib)
try:
self.xml.remove(plugin.xml)
- except:
+ except ValueError:
pass
return self
diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py
index 08d7ad02..c49abd3e 100644
--- a/sleekxmpp/xmlstream/tostring.py
+++ b/sleekxmpp/xmlstream/tostring.py
@@ -24,8 +24,8 @@ if sys.version_info < (3, 0):
XML_NS = 'http://www.w3.org/XML/1998/namespace'
-def tostring(xml=None, xmlns='', stream=None,
- outbuffer='', top_level=False, open_only=False):
+def tostring(xml=None, xmlns='', stream=None, outbuffer='',
+ top_level=False, open_only=False, namespaces=None):
"""Serialize an XML object to a Unicode string.
If an outer xmlns is provided using ``xmlns``, then the current element's
@@ -41,7 +41,8 @@ def tostring(xml=None, xmlns='', stream=None,
during recursive calls.
:param bool top_level: Indicates that the element is the outermost
element.
-
+ :param set namespaces: Track which namespaces are in active use so
+ that new ones can be declared when needed.
:type xml: :py:class:`~xml.etree.ElementTree.Element`
:type stream: :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream`
@@ -63,6 +64,7 @@ def tostring(xml=None, xmlns='', stream=None,
default_ns = ''
stream_ns = ''
use_cdata = False
+
if stream:
default_ns = stream.default_ns
stream_ns = stream.stream_ns
@@ -82,6 +84,7 @@ def tostring(xml=None, xmlns='', stream=None,
output.append(namespace)
# Output escaped attribute values.
+ new_namespaces = set()
for attrib, value in xml.attrib.items():
value = escape(value, use_cdata)
if '}' not in attrib:
@@ -89,14 +92,20 @@ def tostring(xml=None, xmlns='', stream=None,
else:
attrib_ns = attrib.split('}')[0][1:]
attrib = attrib.split('}')[1]
- if stream and attrib_ns in stream.namespace_map:
+ if attrib_ns == XML_NS:
+ output.append(' xml:%s="%s"' % (attrib, value))
+ elif stream and attrib_ns in stream.namespace_map:
mapped_ns = stream.namespace_map[attrib_ns]
if mapped_ns:
- output.append(' %s:%s="%s"' % (mapped_ns,
- attrib,
- value))
- elif attrib_ns == XML_NS:
- output.append(' xml:%s="%s"' % (attrib, value))
+ if namespaces is None:
+ namespaces = set()
+ if attrib_ns not in namespaces:
+ namespaces.add(attrib_ns)
+ new_namespaces.add(attrib_ns)
+ output.append(' xmlns:%s="%s"' % (
+ mapped_ns, attrib_ns))
+ output.append(' %s:%s="%s"' % (
+ mapped_ns, attrib, value))
if open_only:
# Only output the opening tag, regardless of content.
@@ -110,7 +119,8 @@ def tostring(xml=None, xmlns='', stream=None,
output.append(escape(xml.text, use_cdata))
if len(xml):
for child in xml:
- output.append(tostring(child, tag_xmlns, stream))
+ output.append(tostring(child, tag_xmlns, stream,
+ namespaces=namespaces))
output.append("</%s>" % tag_name)
elif xml.text:
# If we only have text content.
@@ -121,6 +131,11 @@ def tostring(xml=None, xmlns='', stream=None,
if xml.tail:
# If there is additional text after the element.
output.append(escape(xml.tail, use_cdata))
+ for ns in new_namespaces:
+ # Remove namespaces introduced in this context. This is necessary
+ # because the namespaces object continues to be shared with other
+ # contexts.
+ namespaces.remove(ns)
return ''.join(output)
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index bea6e88f..f9ec4947 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -26,11 +26,12 @@ import time
import random
import weakref
import uuid
+import errno
from xml.parsers.expat import ExpatError
import sleekxmpp
-from sleekxmpp.util import Queue, QueueEmpty
+from sleekxmpp.util import Queue, QueueEmpty, safedict
from sleekxmpp.thirdparty.statemachine import StateMachine
from sleekxmpp.xmlstream import Scheduler, tostring, cert
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
@@ -49,7 +50,7 @@ RESPONSE_TIMEOUT = 30
#: The time in seconds to wait for events from the event queue, and also the
#: time between checks for the process stop signal.
-WAIT_TIMEOUT = 0.1
+WAIT_TIMEOUT = 1.0
#: The number of threads to use to handle XML stream events. This is not the
#: same as the number of custom event handling threads.
@@ -122,6 +123,11 @@ class XMLStream(object):
#: xmpp.ssl_version = ssl.PROTOCOL_SSLv23
self.ssl_version = ssl.PROTOCOL_TLSv1
+ #: The list of accepted ciphers, in OpenSSL Format.
+ #: It might be useful to override it for improved security
+ #: over the python defaults.
+ self.ciphers = None
+
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
@@ -218,6 +224,11 @@ class XMLStream(object):
#: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
+ #: If set to ``True``, allow using the ``dnspython`` DNS library
+ #: if available. If set to ``False``, the builtin DNS resolver
+ #: will be used, even if ``dnspython`` is installed.
+ self.use_dnspython = True
+
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
self.use_cdata = False
@@ -280,7 +291,7 @@ class XMLStream(object):
self.event_queue = Queue()
#: A queue of string data to be sent over the stream.
- self.send_queue = Queue()
+ self.send_queue = Queue(maxsize=256)
self.send_queue_lock = threading.Lock()
self.send_lock = threading.RLock()
@@ -449,9 +460,11 @@ class XMLStream(object):
def _connect(self, reattempt=True):
self.scheduler.remove('Session timeout check')
- if self.reconnect_delay is None or not reattempt:
+ if self.reconnect_delay is None:
delay = 1.0
- else:
+ self.reconnect_delay = delay
+
+ if reattempt:
delay = min(self.reconnect_delay * 2, self.reconnect_max_delay)
delay = random.normalvariate(delay, delay * 0.1)
log.debug('Waiting %s seconds before connecting.', delay)
@@ -461,10 +474,10 @@ class XMLStream(object):
time.sleep(0.1)
elapsed += 0.1
except KeyboardInterrupt:
- self.stop.set()
+ self.set_stop()
return False
except SystemExit:
- self.stop.set()
+ self.set_stop()
return False
if self.default_domain:
@@ -507,12 +520,19 @@ class XMLStream(object):
else:
cert_policy = ssl.CERT_REQUIRED
- ssl_socket = ssl.wrap_socket(self.socket,
- certfile=self.certfile,
- keyfile=self.keyfile,
- ca_certs=self.ca_certs,
- cert_reqs=cert_policy,
- do_handshake_on_connect=False)
+ ssl_args = safedict({
+ 'certfile': self.certfile,
+ 'keyfile': self.keyfile,
+ 'ca_certs': self.ca_certs,
+ 'cert_reqs': cert_policy,
+ 'do_handshake_on_connect': False,
+ "ssl_version": self.ssl_version
+ })
+
+ if sys.version_info >= (2, 7):
+ ssl_args['ciphers'] = self.ciphers
+
+ ssl_socket = ssl.wrap_socket(self.socket, **ssl_args)
if hasattr(self.socket, 'socket'):
# We are using a testing socket, so preserve the top
@@ -550,7 +570,7 @@ class XMLStream(object):
cert.verify(self._expected_server_name, self._der_cert)
except cert.CertificateError as err:
if not self.event_handled('ssl_invalid_cert'):
- log.error(err.message)
+ log.error(err)
self.disconnect(send_close=False)
else:
self.event('ssl_invalid_cert',
@@ -559,8 +579,7 @@ class XMLStream(object):
self.set_socket(self.socket, ignore=True)
#this event is where you should set your application state
- self.event("connected", direct=True)
- self.reconnect_delay = 1.0
+ self.event('connected', direct=True)
return True
except (Socket.error, ssl.SSLError) as serr:
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
@@ -600,7 +619,7 @@ class XMLStream(object):
headers = '\r\n'.join(headers) + '\r\n\r\n'
try:
- log.debug("Connecting to proxy: %s:%s", address)
+ log.debug("Connecting to proxy: %s:%s", *address)
self.socket.connect(address)
self.send_raw(headers, now=True)
resp = ''
@@ -611,6 +630,7 @@ class XMLStream(object):
lines = resp.split('\r\n')
if '200' not in lines[0]:
self.event('proxy_error', resp)
+ self.event('connection_failed', direct=True)
log.error('Proxy Error: %s', lines[0])
return False
@@ -706,7 +726,7 @@ class XMLStream(object):
self.stream_end_event.set()
if not self.auto_reconnect:
- self.stop.set()
+ self.set_stop()
if self._disconnect_wait_for_threads:
self._wait_for_threads()
@@ -718,12 +738,12 @@ class XMLStream(object):
self.event('socket_error', serr, direct=True)
finally:
#clear your application state
- self.event("disconnected", direct=True)
+ self.event('disconnected', direct=True)
return True
def abort(self):
self.session_started_event.clear()
- self.stop.set()
+ self.set_stop()
if self._disconnect_wait_for_threads:
self._wait_for_threads()
try:
@@ -818,19 +838,26 @@ class XMLStream(object):
to be restarted.
"""
log.info("Negotiating TLS")
- log.info("Using SSL version: %s", str(self.ssl_version))
+ ssl_versions = {3: 'TLS 1.0', 1: 'SSL 3', 2: 'SSL 2/3'}
+ log.info("Using SSL version: %s", ssl_versions[self.ssl_version])
if self.ca_certs is None:
cert_policy = ssl.CERT_NONE
else:
cert_policy = ssl.CERT_REQUIRED
- ssl_socket = ssl.wrap_socket(self.socket,
- certfile=self.certfile,
- keyfile=self.keyfile,
- ssl_version=self.ssl_version,
- do_handshake_on_connect=False,
- ca_certs=self.ca_certs,
- cert_reqs=cert_policy)
+ ssl_args = safedict({
+ 'certfile': self.certfile,
+ 'keyfile': self.keyfile,
+ 'ca_certs': self.ca_certs,
+ 'cert_reqs': cert_policy,
+ 'do_handshake_on_connect': False,
+ "ssl_version": self.ssl_version
+ })
+
+ if sys.version_info >= (2, 7):
+ ssl_args['ciphers'] = self.ciphers
+
+ ssl_socket = ssl.wrap_socket(self.socket, **ssl_args)
if hasattr(self.socket, 'socket'):
# We are using a testing socket, so preserve the top
@@ -859,7 +886,7 @@ class XMLStream(object):
cert.verify(self._expected_server_name, self._der_cert)
except cert.CertificateError as err:
if not self.event_handled('ssl_invalid_cert'):
- log.error(err.message)
+ log.error(err)
self.disconnect(self.auto_reconnect, send_close=False)
else:
self.event('ssl_invalid_cert', pem_cert, direct=True)
@@ -915,12 +942,13 @@ class XMLStream(object):
self.whitespace_keepalive_interval = 300
"""
- self.schedule('Whitespace Keepalive',
- self.whitespace_keepalive_interval,
- self.send_raw,
- args=(' ',),
- kwargs={'now': True},
- repeat=True)
+ if self.whitespace_keepalive:
+ self.schedule('Whitespace Keepalive',
+ self.whitespace_keepalive_interval,
+ self.send_raw,
+ args=(' ',),
+ kwargs={'now': True},
+ repeat=True)
def _remove_schedules(self, event):
"""Remove whitespace keepalive and certificate expiration schedules."""
@@ -1016,9 +1044,13 @@ class XMLStream(object):
# and handler classes here.
if name is None:
- name = 'add_handler_%s' % self.getNewId()
- self.registerHandler(XMLCallback(name, MatchXMLMask(mask), pointer,
- once=disposable, instream=instream))
+ name = 'add_handler_%s' % self.new_id()
+ self.register_handler(
+ XMLCallback(name,
+ MatchXMLMask(mask, self.default_ns),
+ pointer,
+ once=disposable,
+ instream=instream))
def register_handler(self, handler, before=None, after=None):
"""Add a stream event handler that will be executed when a matching
@@ -1059,7 +1091,8 @@ class XMLStream(object):
return resolve(domain, port, service=self.dns_service,
resolver=resolver,
- use_ipv6=self.use_ipv6)
+ use_ipv6=self.use_ipv6,
+ use_dnspython=self.use_dnspython)
def pick_dns_answer(self, domain, port=None):
"""Pick a server and port from DNS answers.
@@ -1120,7 +1153,7 @@ class XMLStream(object):
"""
return len(self.__event_handlers.get(name, []))
- def event(self, name, data={}, direct=False):
+ def event(self, name, data=None, direct=False):
"""Manually trigger a custom event.
:param name: The name of the event to trigger.
@@ -1131,6 +1164,11 @@ class XMLStream(object):
event queue. All event handlers will run in the
same thread.
"""
+ if not data:
+ data = {}
+
+ log.debug("Event triggered: " + name)
+
handlers = self.__event_handlers.get(name, [])
for handler in handlers:
#TODO: Data should not be copied, but should be read only,
@@ -1302,6 +1340,9 @@ class XMLStream(object):
if not self.stop.is_set():
time.sleep(self.ssl_retry_delay)
tries += 1
+ except Socket.error as serr:
+ if serr.errno != errno.EINTR:
+ raise
if count > 1:
log.debug('SENT: %d chunks', count)
except (Socket.error, ssl.SSLError) as serr:
@@ -1316,12 +1357,12 @@ class XMLStream(object):
return True
def _start_thread(self, name, target, track=True):
- self.__active_threads.add(name)
self.__thread[name] = threading.Thread(name=name, target=target)
self.__thread[name].daemon = self._use_daemons
self.__thread[name].start()
if track:
+ self.__active_threads.add(name)
with self.__thread_cond:
self.__thread_count += 1
@@ -1350,6 +1391,13 @@ class XMLStream(object):
if self.__thread_count == 0:
self.__thread_cond.notify()
+ def set_stop(self):
+ self.stop.set()
+
+ # Unlock queues
+ self.event_queue.put(None)
+ self.send_queue.put(None)
+
def _wait_for_threads(self):
with self.__thread_cond:
if self.__thread_count != 0:
@@ -1493,6 +1541,10 @@ class XMLStream(object):
# as handshakes.
self.stream_end_event.clear()
self.start_stream_handler(root)
+
+ # We have a successful stream connection, so reset
+ # exponential backoff for new reconnect attempts.
+ self.reconnect_delay = 1.0
depth += 1
if event == b'end':
depth -= 1
@@ -1618,11 +1670,7 @@ class XMLStream(object):
log.debug("Loading event runner")
try:
while not self.stop.is_set():
- try:
- wait = self.wait_timeout
- event = self.event_queue.get(True, timeout=wait)
- except QueueEmpty:
- event = None
+ event = self.event_queue.get()
if event is None:
continue
@@ -1638,10 +1686,10 @@ class XMLStream(object):
log.exception(error_msg, handler.name)
orig.exception(e)
elif etype == 'schedule':
- name = args[1]
+ name = args[2]
try:
log.debug('Scheduled event: %s: %s', name, args[0])
- handler(*args[0])
+ handler(*args[0], **args[1])
except Exception as e:
log.exception('Error processing scheduled task')
self.exception(e)
@@ -1683,14 +1731,13 @@ class XMLStream(object):
while not self.stop.is_set():
while not self.stop.is_set() and \
not self.session_started_event.is_set():
- self.session_started_event.wait(timeout=0.1)
+ self.session_started_event.wait(timeout=0.1) # Wait for session start
if self.__failed_send_stanza is not None:
data = self.__failed_send_stanza
self.__failed_send_stanza = None
else:
- try:
- data = self.send_queue.get(True, 1)
- except QueueEmpty:
+ data = self.send_queue.get() # Wait for data to send
+ if data is None:
continue
log.debug("SEND: %s", data)
enc_data = data.encode('utf-8')
@@ -1717,6 +1764,9 @@ class XMLStream(object):
if not self.stop.is_set():
time.sleep(self.ssl_retry_delay)
tries += 1
+ except Socket.error as serr:
+ if serr.errno != errno.EINTR:
+ raise
if count > 1:
log.debug('SENT: %d chunks', count)
self.send_queue.task_done()