summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r--sleekxmpp/xmlstream/handler/waiter.py1
-rw-r--r--sleekxmpp/xmlstream/jid.py4
-rw-r--r--sleekxmpp/xmlstream/matcher/stanzapath.py11
-rw-r--r--sleekxmpp/xmlstream/scheduler.py2
-rw-r--r--sleekxmpp/xmlstream/stanzabase.py98
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py193
6 files changed, 232 insertions, 77 deletions
diff --git a/sleekxmpp/xmlstream/handler/waiter.py b/sleekxmpp/xmlstream/handler/waiter.py
index 01ff5d67..899df17c 100644
--- a/sleekxmpp/xmlstream/handler/waiter.py
+++ b/sleekxmpp/xmlstream/handler/waiter.py
@@ -15,7 +15,6 @@ try:
except ImportError:
import Queue as queue
-from sleekxmpp.xmlstream import StanzaBase
from sleekxmpp.xmlstream.handler.base import BaseHandler
diff --git a/sleekxmpp/xmlstream/jid.py b/sleekxmpp/xmlstream/jid.py
index c91c8fb3..281bf4ee 100644
--- a/sleekxmpp/xmlstream/jid.py
+++ b/sleekxmpp/xmlstream/jid.py
@@ -139,3 +139,7 @@ class JID(object):
def __ne__(self, other):
"""Two JIDs are considered unequal if they are not equal."""
return not self == other
+
+ def __hash__(self):
+ """Hash a JID based on the string version of its full JID."""
+ return hash(self.full)
diff --git a/sleekxmpp/xmlstream/matcher/stanzapath.py b/sleekxmpp/xmlstream/matcher/stanzapath.py
index 61c5332c..a4c0fda0 100644
--- a/sleekxmpp/xmlstream/matcher/stanzapath.py
+++ b/sleekxmpp/xmlstream/matcher/stanzapath.py
@@ -10,6 +10,7 @@
"""
from sleekxmpp.xmlstream.matcher.base import MatcherBase
+from sleekxmpp.xmlstream.stanzabase import fix_ns
class StanzaPath(MatcherBase):
@@ -18,8 +19,16 @@ class StanzaPath(MatcherBase):
The StanzaPath matcher selects stanzas that match a given "stanza path",
which is similar to a normal XPath except that it uses the interfaces and
plugins of the stanza instead of the actual, underlying XML.
+
+ :param criteria: Object to compare some aspect of a stanza against.
"""
+ def __init__(self, criteria):
+ self._criteria = fix_ns(criteria, split=True,
+ propagate_ns=False,
+ default_ns='jabber:client')
+ self._raw_criteria = criteria
+
def match(self, stanza):
"""
Compare a stanza against a "stanza path". A stanza path is similar to
@@ -31,4 +40,4 @@ class StanzaPath(MatcherBase):
:param stanza: The :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase`
stanza to compare against.
"""
- return stanza.match(self._criteria)
+ return stanza.match(self._criteria) or stanza.match(self._raw_criteria)
diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py
index 4a6f073f..8ec73164 100644
--- a/sleekxmpp/xmlstream/scheduler.py
+++ b/sleekxmpp/xmlstream/scheduler.py
@@ -161,7 +161,7 @@ class Scheduler(object):
else:
break
for task in cleanup:
- x = self.schedule.pop(self.schedule.index(task))
+ self.schedule.pop(self.schedule.index(task))
else:
updated = True
self.schedule_lock.acquire()
diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py
index 721181a8..96b4f181 100644
--- a/sleekxmpp/xmlstream/stanzabase.py
+++ b/sleekxmpp/xmlstream/stanzabase.py
@@ -14,7 +14,6 @@
import copy
import logging
-import sys
import weakref
from xml.etree import cElementTree as ET
@@ -77,6 +76,49 @@ def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False):
registerStanzaPlugin = register_stanza_plugin
+def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''):
+ """Apply the stanza's namespace to elements in an XPath expression.
+
+ :param string xpath: The XPath expression to fix with namespaces.
+ :param bool split: Indicates if the fixed XPath should be left as a
+ list of element names with namespaces. Defaults to
+ False, which returns a flat string path.
+ :param bool propagate_ns: Overrides propagating parent element
+ namespaces to child elements. Useful if
+ you wish to simply split an XPath that has
+ non-specified namespaces, and child and
+ parent namespaces are known not to always
+ match. Defaults to True.
+ """
+ fixed = []
+ # Split the XPath into a series of blocks, where a block
+ # is started by an element with a namespace.
+ ns_blocks = xpath.split('{')
+ for ns_block in ns_blocks:
+ if '}' in ns_block:
+ # Apply the found namespace to following elements
+ # that do not have namespaces.
+ namespace = ns_block.split('}')[0]
+ elements = ns_block.split('}')[1].split('/')
+ else:
+ # Apply the stanza's namespace to the following
+ # elements since no namespace was provided.
+ namespace = default_ns
+ elements = ns_block.split('/')
+
+ for element in elements:
+ if element:
+ # Skip empty entry artifacts from splitting.
+ if propagate_ns:
+ tag = '{%s}%s' % (namespace, element)
+ else:
+ tag = element
+ fixed.append(tag)
+ if split:
+ return fixed
+ return '/'.join(fixed)
+
+
class ElementBase(object):
"""
@@ -309,6 +351,7 @@ class ElementBase(object):
if self.xml is None:
self.xml = xml
+ last_xml = self.xml
if self.xml is None:
# Generate XML from the stanza definition
for ename in self.name.split('/'):
@@ -345,7 +388,8 @@ class ElementBase(object):
"""
if attrib not in self.plugins:
plugin_class = self.plugin_attrib_map[attrib]
- plugin = plugin_class(parent=self)
+ existing_xml = self.xml.find(plugin_class.tag_name())
+ plugin = plugin_class(parent=self, xml=existing_xml)
self.plugins[attrib] = plugin
if plugin_class in self.plugin_iterables:
self.iterables.append(plugin)
@@ -759,7 +803,7 @@ class ElementBase(object):
may be either a string or a list of element
names with attribute checks.
"""
- if isinstance(xpath, str):
+ if not isinstance(xpath, list):
xpath = self._fix_ns(xpath, split=True, propagate_ns=False)
# Extract the tag name and attribute checks for the first XPath node.
@@ -917,8 +961,9 @@ class ElementBase(object):
Any attribute values will be preserved.
"""
- for child in self.xml.getchildren():
+ for child in list(self.xml):
self.xml.remove(child)
+
for plugin in list(self.plugins.keys()):
del self.plugins[plugin]
return self
@@ -951,46 +996,9 @@ class ElementBase(object):
return self
def _fix_ns(self, xpath, split=False, propagate_ns=True):
- """Apply the stanza's namespace to elements in an XPath expression.
-
- :param string xpath: The XPath expression to fix with namespaces.
- :param bool split: Indicates if the fixed XPath should be left as a
- list of element names with namespaces. Defaults to
- False, which returns a flat string path.
- :param bool propagate_ns: Overrides propagating parent element
- namespaces to child elements. Useful if
- you wish to simply split an XPath that has
- non-specified namespaces, and child and
- parent namespaces are known not to always
- match. Defaults to True.
- """
- fixed = []
- # Split the XPath into a series of blocks, where a block
- # is started by an element with a namespace.
- ns_blocks = xpath.split('{')
- for ns_block in ns_blocks:
- if '}' in ns_block:
- # Apply the found namespace to following elements
- # that do not have namespaces.
- namespace = ns_block.split('}')[0]
- elements = ns_block.split('}')[1].split('/')
- else:
- # Apply the stanza's namespace to the following
- # elements since no namespace was provided.
- namespace = self.namespace
- elements = ns_block.split('/')
-
- for element in elements:
- if element:
- # Skip empty entry artifacts from splitting.
- if propagate_ns:
- tag = '{%s}%s' % (namespace, element)
- else:
- tag = element
- fixed.append(tag)
- if split:
- return fixed
- return '/'.join(fixed)
+ return fix_ns(xpath, split=split,
+ propagate_ns=propagate_ns,
+ default_ns=self.namespace)
def __eq__(self, other):
"""Compare the stanza object with another to test for equality.
@@ -1251,7 +1259,7 @@ class StanzaBase(ElementBase):
stanza sent immediately. Useful for stream
initialization. Defaults to ``False``.
"""
- self.stream.send_raw(self.__str__(), now=now)
+ self.stream.send(self, now=now)
def __copy__(self):
"""Return a copy of the stanza object that does not share the
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index fb9f91bc..22469039 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -24,7 +24,6 @@ import ssl
import sys
import threading
import time
-import types
import random
import weakref
try:
@@ -32,10 +31,12 @@ try:
except ImportError:
import Queue as queue
+from xml.parsers.expat import ExpatError
+
import sleekxmpp
from sleekxmpp.thirdparty.statemachine import StateMachine
from sleekxmpp.xmlstream import Scheduler, tostring
-from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET
+from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
from sleekxmpp.xmlstream.matcher import MatchXMLMask
@@ -80,6 +81,12 @@ SSL_RETRY_MAX = 10
#: Maximum time to delay between connection attempts is one hour.
RECONNECT_MAX_DELAY = 600
+#: Maximum number of attempts to connect to the server before quitting
+#: and raising a 'connect_failed' event. Setting this to ``None`` will
+#: allow infinite reconnection attempts, and using ``0`` will disable
+#: reconnections. Defaults to ``None``.
+RECONNECT_MAX_ATTEMPTS = None
+
log = logging.getLogger(__name__)
@@ -156,6 +163,12 @@ class XMLStream(object):
#: Maximum time to delay between connection attempts is one hour.
self.reconnect_max_delay = RECONNECT_MAX_DELAY
+ #: Maximum number of attempts to connect to the server before
+ #: quitting and raising a 'connect_failed' event. Setting to
+ #: ``None`` allows infinite reattempts, while setting it to ``0``
+ #: will disable reconnection attempts. Defaults to ``None``.
+ self.reconnect_max_attempts = RECONNECT_MAX_ATTEMPTS
+
#: The time in seconds to delay between attempts to resend data
#: after an SSL error.
self.ssl_retry_max = SSL_RETRY_MAX
@@ -254,6 +267,7 @@ class XMLStream(object):
#: A queue of string data to be sent over the stream.
self.send_queue = queue.Queue()
+ self.send_queue_lock = threading.Lock()
#: A :class:`~sleekxmpp.xmlstream.scheduler.Scheduler` instance for
#: executing callbacks in the future based on time delays.
@@ -268,6 +282,7 @@ class XMLStream(object):
self.__handlers = []
self.__event_handlers = {}
self.__event_handlers_lock = threading.Lock()
+ self.__filters = {'in': [], 'out': [], 'out_sync': []}
self._id = 0
self._id_lock = threading.Lock()
@@ -381,13 +396,21 @@ class XMLStream(object):
if use_tls is not None:
self.use_tls = use_tls
+
# Repeatedly attempt to connect until a successful connection
# is established.
+ attempts = self.reconnect_max_attempts
connected = self.state.transition('disconnected', 'connected',
func=self._connect)
while reattempt and not connected and not self.stop.is_set():
connected = self.state.transition('disconnected', 'connected',
func=self._connect)
+ if not connected:
+ if attempts is not None:
+ attempts -= 1
+ if attempts <= 0:
+ self.event('connection_failed', direct=True)
+ return False
return connected
def _connect(self):
@@ -396,9 +419,7 @@ class XMLStream(object):
if self.default_domain:
self.address = self.pick_dns_answer(self.default_domain,
self.address[1])
- self.socket = self.socket_class(Socket.AF_INET, Socket.SOCK_STREAM)
- self.configure_socket()
-
+
if self.reconnect_delay is None:
delay = 1.0
else:
@@ -417,6 +438,27 @@ class XMLStream(object):
self.stop.set()
return False
+ try:
+ # Look for IPv6 addresses, in addition to IPv4
+ for res in Socket.getaddrinfo(self.address[0],
+ int(self.address[1]),
+ 0,
+ Socket.SOCK_STREAM):
+ log.debug("Trying: %s", res[-1])
+ af, sock_type, proto, canonical, sock_addr = res
+ try:
+ self.socket = self.socket_class(af, sock_type, proto)
+ break
+ except Socket.error:
+ log.debug("Could not open IPv%s socket." % proto)
+ except Socket.gaierror:
+ log.warning("Socket could not be opened: no connectivity" + \
+ " or wrong IP versions.")
+ self.reconnect_delay = delay
+ return False
+
+ self.configure_socket()
+
if self.use_proxy:
connected = self._connect_proxy()
if not connected:
@@ -446,6 +488,12 @@ class XMLStream(object):
log.debug("Connecting to %s:%s", *self.address)
self.socket.connect(self.address)
+ if self.use_ssl and self.ssl_support:
+ cert = self.socket.getpeercert(binary_form=True)
+ cert = ssl.DER_cert_to_PEM_cert(cert)
+ log.debug('CERT: %s', cert)
+ self.event('ssl_cert', cert, direct=True)
+
self.set_socket(self.socket, ignore=True)
#this event is where you should set your application state
self.event("connected", direct=True)
@@ -453,7 +501,7 @@ class XMLStream(object):
return True
except Socket.error as serr:
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
- self.event('socket_error', serr)
+ self.event('socket_error', serr, direct=True)
log.error(error_msg, self.address[0], self.address[1],
serr.errno, serr.strerror)
self.reconnect_delay = delay
@@ -506,7 +554,7 @@ class XMLStream(object):
return True
except Socket.error as serr:
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
- self.event('socket_error', serr)
+ self.event('socket_error', serr, direct=True)
log.error(error_msg, self.address[0], self.address[1],
serr.errno, serr.strerror)
return False
@@ -550,6 +598,7 @@ class XMLStream(object):
:attr:`disconnect_wait`.
"""
self.state.transition('connected', 'disconnected',
+ wait=2.0,
func=self._disconnect, args=(reconnect, wait))
def _disconnect(self, reconnect=False, wait=None):
@@ -577,7 +626,7 @@ class XMLStream(object):
self.socket.close()
self.filesocket.close()
except Socket.error as serr:
- self.event('socket_error', serr)
+ self.event('socket_error', serr, direct=True)
finally:
#clear your application state
self.event("disconnected", direct=True)
@@ -590,6 +639,8 @@ class XMLStream(object):
self.state.transition('connected', 'disconnected', wait=2.0,
func=self._disconnect, args=(True,))
+ attempts = self.reconnect_max_attempts
+
log.debug("connecting...")
connected = self.state.transition('disconnected', 'connected',
wait=2.0, func=self._connect)
@@ -597,6 +648,12 @@ class XMLStream(object):
connected = self.state.transition('disconnected', 'connected',
wait=2.0, func=self._connect)
connected = connected or self.state.ensure('connected')
+ if not connected:
+ if attempts is not None:
+ attempts -= 1
+ if attempts <= 0:
+ self.event('connection_failed', direct=True)
+ return False
return connected
def set_socket(self, socket, ignore=False):
@@ -674,6 +731,12 @@ class XMLStream(object):
else:
self.socket = ssl_socket
self.socket.do_handshake()
+
+ cert = self.socket.getpeercert(binary_form=True)
+ cert = ssl.DER_cert_to_PEM_cert(cert)
+ log.debug('CERT: %s', cert)
+ self.event('ssl_cert', cert, direct=True)
+
self.set_socket(self.socket)
return True
else:
@@ -741,7 +804,29 @@ class XMLStream(object):
stanza objects, but may still be processed using handlers and
matchers.
"""
- del self.__root_stanza[stanza_class]
+ self.__root_stanza.remove(stanza_class)
+
+ def add_filter(self, mode, handler, order=None):
+ """Add a filter for incoming or outgoing stanzas.
+
+ These filters are applied before incoming stanzas are
+ passed to any handlers, and before outgoing stanzas
+ are put in the send queue.
+
+ Each filter must accept a single stanza, and return
+ either a stanza or ``None``. If the filter returns
+ ``None``, then the stanza will be dropped from being
+ processed for events or from being sent.
+
+ :param mode: One of ``'in'`` or ``'out'``.
+ :param handler: The filter function.
+ :param int order: The position to insert the filter in
+ the list of active filters.
+ """
+ if order:
+ self.__filters[mode].insert(order, handler)
+ else:
+ self.__filters[mode].append(handler)
def add_handler(self, mask, pointer, name=None, disposable=False,
threaded=False, filter=False, instream=False):
@@ -808,20 +893,44 @@ class XMLStream(object):
resolver = dns.resolver.get_default_resolver()
self.configure_dns(resolver, domain=domain, port=port)
+ v4_answers = []
+ v6_answers = []
+ answers = []
+
try:
- answers = resolver.query(domain, dns.rdatatype.A)
+ log.debug("Querying A records for %s" % domain)
+ v4_answers = resolver.query(domain, dns.rdatatype.A)
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
log.warning("No A records for %s", domain)
- return [((domain, port), 0, 0)]
+ v4_answers = [((domain, port), 0, 0)]
except dns.exception.Timeout:
log.warning("DNS resolution timed out " + \
"for A record of %s", domain)
- return [((domain, port), 0, 0)]
+ v4_answers = [((domain, port), 0, 0)]
else:
- return [((ans.address, port), 0, 0) for ans in answers]
+ for ans in v4_answers:
+ log.debug("Found A record: %s", ans.address)
+ answers.append(((ans.address, port), 0, 0))
+
+ try:
+ log.debug("Querying AAAA records for %s" % domain)
+ v6_answers = resolver.query(domain, dns.rdatatype.AAAA)
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ log.warning("No AAAA records for %s", domain)
+ v6_answers = [((domain, port), 0, 0)]
+ except dns.exception.Timeout:
+ log.warning("DNS resolution timed out " + \
+ "for AAAA record of %s", domain)
+ v6_answers = [((domain, port), 0, 0)]
+ else:
+ for ans in v6_answers:
+ log.debug("Found AAAA record: %s", ans.address)
+ answers.append(((ans.address, port), 0, 0))
+
+ return answers
else:
log.warning("dnspython is not installed -- " + \
- "relying on OS A record resolution")
+ "relying on OS A/AAAA record resolution")
self.configure_dns(None, domain=domain, port=port)
return [((domain, port), 0, 0)]
@@ -850,6 +959,7 @@ class XMLStream(object):
items = [x for x in addresses.keys()]
items.sort()
+ address = (domain, port)
picked = random.randint(0, intmax)
for item in items:
if picked <= item:
@@ -857,8 +967,8 @@ class XMLStream(object):
break
for idx, answer in enumerate(self.dns_answers):
if self.dns_answers[0] == address:
+ self.dns_answers.pop(idx)
break
- self.dns_answers.pop(idx)
log.debug("Trying to connect to %s:%s", *address)
return address
@@ -971,7 +1081,7 @@ class XMLStream(object):
"""
return xml
- def send(self, data, mask=None, timeout=None, now=False):
+ def send(self, data, mask=None, timeout=None, now=False, use_filters=True):
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
May optionally block until an expected response is received.
@@ -989,18 +1099,40 @@ class XMLStream(object):
sending the stanza immediately. Useful mainly
for stream initialization stanzas.
Defaults to ``False``.
+ :param bool use_filters: Indicates if outgoing filters should be
+ applied to the given stanza data. Disabling
+ filters is useful when resending stanzas.
+ Defaults to ``True``.
"""
if timeout is None:
timeout = self.response_timeout
if hasattr(mask, 'xml'):
mask = mask.xml
- data = str(data)
+
+ if isinstance(data, ElementBase):
+ if use_filters:
+ for filter in self.__filters['out']:
+ data = filter(data)
+ if data is None:
+ return
+
if mask is not None:
log.warning("Use of send mask waiters is deprecated.")
wait_for = Waiter("SendWait_%s" % self.new_id(),
MatchXMLMask(mask))
self.register_handler(wait_for)
- self.send_raw(data, now)
+
+ if isinstance(data, ElementBase):
+ with self.send_queue_lock:
+ if use_filters:
+ for filter in self.__filters['out_sync']:
+ data = filter(data)
+ if data is None:
+ return
+ str_data = str(data)
+ self.send_raw(str_data, now)
+ else:
+ self.send_raw(data, now)
if mask is not None:
return wait_for.wait(timeout)
@@ -1061,7 +1193,7 @@ class XMLStream(object):
if count > 1:
log.debug('SENT: %d chunks', count)
except Socket.error as serr:
- self.event('socket_error', serr)
+ self.event('socket_error', serr, direct=True)
log.warning("Failed to send %s", data)
if reconnect is None:
reconnect = self.auto_reconnect
@@ -1157,12 +1289,11 @@ class XMLStream(object):
except SystemExit:
log.debug("SystemExit in _process")
shutdown = True
- except SyntaxError as e:
+ except (SyntaxError, ExpatError) as e:
log.error("Error reading from XML stream.")
- shutdown = True
self.exception(e)
except Socket.error as serr:
- self.event('socket_error', serr)
+ self.event('socket_error', serr, direct=True)
log.exception('Socket Error')
except Exception as e:
if not self.stop.is_set():
@@ -1246,8 +1377,6 @@ class XMLStream(object):
:param xml: The :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase`
stanza to analyze.
"""
- log.debug("RECV: %s", tostring(xml, xmlns=self.default_ns,
- stream=self))
# Apply any preprocessing filters.
xml = self.incoming_filter(xml)
@@ -1255,6 +1384,14 @@ class XMLStream(object):
# stanza type applies, a generic StanzaBase stanza will be used.
stanza = self._build_stanza(xml)
+ for filter in self.__filters['in']:
+ if stanza is not None:
+ stanza = filter(stanza)
+ if stanza is None:
+ return
+
+ log.debug("RECV: %s", stanza)
+
# Match the stanza against registered handlers. Handlers marked
# to run "in stream" will be executed immediately; the rest will
# be queued.
@@ -1371,7 +1508,7 @@ class XMLStream(object):
"""Extract stanzas from the send queue and send them on the stream."""
try:
while not self.stop.is_set():
- while not self.stop.is_set and \
+ while not self.stop.is_set() and \
not self.session_started_event.is_set():
self.session_started_event.wait(timeout=1)
if self.__failed_send_stanza is not None:
@@ -1398,9 +1535,7 @@ class XMLStream(object):
log.debug('SSL error - max retries reached')
self.exception(serr)
log.warning("Failed to send %s", data)
- if reconnect is None:
- reconnect = self.auto_reconnect
- self.disconnect(reconnect)
+ self.disconnect(self.auto_reconnect)
log.warning('SSL write error - reattempting')
time.sleep(self.ssl_retry_delay)
tries += 1
@@ -1408,7 +1543,7 @@ class XMLStream(object):
log.debug('SENT: %d chunks', count)
self.send_queue.task_done()
except Socket.error as serr:
- self.event('socket_error', serr)
+ self.event('socket_error', serr, direct=True)
log.warning("Failed to send %s", data)
self.__failed_send_stanza = data
self.disconnect(self.auto_reconnect)