diff options
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r-- | sleekxmpp/xmlstream/handler/waiter.py | 1 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/jid.py | 4 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/matcher/stanzapath.py | 11 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/scheduler.py | 2 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/stanzabase.py | 98 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 212 |
6 files changed, 244 insertions, 84 deletions
diff --git a/sleekxmpp/xmlstream/handler/waiter.py b/sleekxmpp/xmlstream/handler/waiter.py index 01ff5d67..899df17c 100644 --- a/sleekxmpp/xmlstream/handler/waiter.py +++ b/sleekxmpp/xmlstream/handler/waiter.py @@ -15,7 +15,6 @@ try: except ImportError: import Queue as queue -from sleekxmpp.xmlstream import StanzaBase from sleekxmpp.xmlstream.handler.base import BaseHandler diff --git a/sleekxmpp/xmlstream/jid.py b/sleekxmpp/xmlstream/jid.py index c91c8fb3..281bf4ee 100644 --- a/sleekxmpp/xmlstream/jid.py +++ b/sleekxmpp/xmlstream/jid.py @@ -139,3 +139,7 @@ class JID(object): def __ne__(self, other): """Two JIDs are considered unequal if they are not equal.""" return not self == other + + def __hash__(self): + """Hash a JID based on the string version of its full JID.""" + return hash(self.full) diff --git a/sleekxmpp/xmlstream/matcher/stanzapath.py b/sleekxmpp/xmlstream/matcher/stanzapath.py index 61c5332c..a4c0fda0 100644 --- a/sleekxmpp/xmlstream/matcher/stanzapath.py +++ b/sleekxmpp/xmlstream/matcher/stanzapath.py @@ -10,6 +10,7 @@ """ from sleekxmpp.xmlstream.matcher.base import MatcherBase +from sleekxmpp.xmlstream.stanzabase import fix_ns class StanzaPath(MatcherBase): @@ -18,8 +19,16 @@ class StanzaPath(MatcherBase): The StanzaPath matcher selects stanzas that match a given "stanza path", which is similar to a normal XPath except that it uses the interfaces and plugins of the stanza instead of the actual, underlying XML. + + :param criteria: Object to compare some aspect of a stanza against. """ + def __init__(self, criteria): + self._criteria = fix_ns(criteria, split=True, + propagate_ns=False, + default_ns='jabber:client') + self._raw_criteria = criteria + def match(self, stanza): """ Compare a stanza against a "stanza path". A stanza path is similar to @@ -31,4 +40,4 @@ class StanzaPath(MatcherBase): :param stanza: The :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase` stanza to compare against. """ - return stanza.match(self._criteria) + return stanza.match(self._criteria) or stanza.match(self._raw_criteria) diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py index 4a6f073f..8ec73164 100644 --- a/sleekxmpp/xmlstream/scheduler.py +++ b/sleekxmpp/xmlstream/scheduler.py @@ -161,7 +161,7 @@ class Scheduler(object): else: break for task in cleanup: - x = self.schedule.pop(self.schedule.index(task)) + self.schedule.pop(self.schedule.index(task)) else: updated = True self.schedule_lock.acquire() diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py index 721181a8..96b4f181 100644 --- a/sleekxmpp/xmlstream/stanzabase.py +++ b/sleekxmpp/xmlstream/stanzabase.py @@ -14,7 +14,6 @@ import copy import logging -import sys import weakref from xml.etree import cElementTree as ET @@ -77,6 +76,49 @@ def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False): registerStanzaPlugin = register_stanza_plugin +def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''): + """Apply the stanza's namespace to elements in an XPath expression. + + :param string xpath: The XPath expression to fix with namespaces. + :param bool split: Indicates if the fixed XPath should be left as a + list of element names with namespaces. Defaults to + False, which returns a flat string path. + :param bool propagate_ns: Overrides propagating parent element + namespaces to child elements. Useful if + you wish to simply split an XPath that has + non-specified namespaces, and child and + parent namespaces are known not to always + match. Defaults to True. + """ + fixed = [] + # Split the XPath into a series of blocks, where a block + # is started by an element with a namespace. + ns_blocks = xpath.split('{') + for ns_block in ns_blocks: + if '}' in ns_block: + # Apply the found namespace to following elements + # that do not have namespaces. + namespace = ns_block.split('}')[0] + elements = ns_block.split('}')[1].split('/') + else: + # Apply the stanza's namespace to the following + # elements since no namespace was provided. + namespace = default_ns + elements = ns_block.split('/') + + for element in elements: + if element: + # Skip empty entry artifacts from splitting. + if propagate_ns: + tag = '{%s}%s' % (namespace, element) + else: + tag = element + fixed.append(tag) + if split: + return fixed + return '/'.join(fixed) + + class ElementBase(object): """ @@ -309,6 +351,7 @@ class ElementBase(object): if self.xml is None: self.xml = xml + last_xml = self.xml if self.xml is None: # Generate XML from the stanza definition for ename in self.name.split('/'): @@ -345,7 +388,8 @@ class ElementBase(object): """ if attrib not in self.plugins: plugin_class = self.plugin_attrib_map[attrib] - plugin = plugin_class(parent=self) + existing_xml = self.xml.find(plugin_class.tag_name()) + plugin = plugin_class(parent=self, xml=existing_xml) self.plugins[attrib] = plugin if plugin_class in self.plugin_iterables: self.iterables.append(plugin) @@ -759,7 +803,7 @@ class ElementBase(object): may be either a string or a list of element names with attribute checks. """ - if isinstance(xpath, str): + if not isinstance(xpath, list): xpath = self._fix_ns(xpath, split=True, propagate_ns=False) # Extract the tag name and attribute checks for the first XPath node. @@ -917,8 +961,9 @@ class ElementBase(object): Any attribute values will be preserved. """ - for child in self.xml.getchildren(): + for child in list(self.xml): self.xml.remove(child) + for plugin in list(self.plugins.keys()): del self.plugins[plugin] return self @@ -951,46 +996,9 @@ class ElementBase(object): return self def _fix_ns(self, xpath, split=False, propagate_ns=True): - """Apply the stanza's namespace to elements in an XPath expression. - - :param string xpath: The XPath expression to fix with namespaces. - :param bool split: Indicates if the fixed XPath should be left as a - list of element names with namespaces. Defaults to - False, which returns a flat string path. - :param bool propagate_ns: Overrides propagating parent element - namespaces to child elements. Useful if - you wish to simply split an XPath that has - non-specified namespaces, and child and - parent namespaces are known not to always - match. Defaults to True. - """ - fixed = [] - # Split the XPath into a series of blocks, where a block - # is started by an element with a namespace. - ns_blocks = xpath.split('{') - for ns_block in ns_blocks: - if '}' in ns_block: - # Apply the found namespace to following elements - # that do not have namespaces. - namespace = ns_block.split('}')[0] - elements = ns_block.split('}')[1].split('/') - else: - # Apply the stanza's namespace to the following - # elements since no namespace was provided. - namespace = self.namespace - elements = ns_block.split('/') - - for element in elements: - if element: - # Skip empty entry artifacts from splitting. - if propagate_ns: - tag = '{%s}%s' % (namespace, element) - else: - tag = element - fixed.append(tag) - if split: - return fixed - return '/'.join(fixed) + return fix_ns(xpath, split=split, + propagate_ns=propagate_ns, + default_ns=self.namespace) def __eq__(self, other): """Compare the stanza object with another to test for equality. @@ -1251,7 +1259,7 @@ class StanzaBase(ElementBase): stanza sent immediately. Useful for stream initialization. Defaults to ``False``. """ - self.stream.send_raw(self.__str__(), now=now) + self.stream.send(self, now=now) def __copy__(self): """Return a copy of the stanza object that does not share the diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index fb9f91bc..6ba82c37 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -24,7 +24,6 @@ import ssl import sys import threading import time -import types import random import weakref try: @@ -32,10 +31,12 @@ try: except ImportError: import Queue as queue +from xml.parsers.expat import ExpatError + import sleekxmpp from sleekxmpp.thirdparty.statemachine import StateMachine from sleekxmpp.xmlstream import Scheduler, tostring -from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET +from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from sleekxmpp.xmlstream.handler import Waiter, XMLCallback from sleekxmpp.xmlstream.matcher import MatchXMLMask @@ -80,6 +81,12 @@ SSL_RETRY_MAX = 10 #: Maximum time to delay between connection attempts is one hour. RECONNECT_MAX_DELAY = 600 +#: Maximum number of attempts to connect to the server before quitting +#: and raising a 'connect_failed' event. Setting this to ``None`` will +#: allow infinite reconnection attempts, and using ``0`` will disable +#: reconnections. Defaults to ``None``. +RECONNECT_MAX_ATTEMPTS = None + log = logging.getLogger(__name__) @@ -156,6 +163,12 @@ class XMLStream(object): #: Maximum time to delay between connection attempts is one hour. self.reconnect_max_delay = RECONNECT_MAX_DELAY + #: Maximum number of attempts to connect to the server before + #: quitting and raising a 'connect_failed' event. Setting to + #: ``None`` allows infinite reattempts, while setting it to ``0`` + #: will disable reconnection attempts. Defaults to ``None``. + self.reconnect_max_attempts = RECONNECT_MAX_ATTEMPTS + #: The time in seconds to delay between attempts to resend data #: after an SSL error. self.ssl_retry_max = SSL_RETRY_MAX @@ -254,6 +267,7 @@ class XMLStream(object): #: A queue of string data to be sent over the stream. self.send_queue = queue.Queue() + self.send_queue_lock = threading.Lock() #: A :class:`~sleekxmpp.xmlstream.scheduler.Scheduler` instance for #: executing callbacks in the future based on time delays. @@ -268,6 +282,7 @@ class XMLStream(object): self.__handlers = [] self.__event_handlers = {} self.__event_handlers_lock = threading.Lock() + self.__filters = {'in': [], 'out': [], 'out_sync': []} self._id = 0 self._id_lock = threading.Lock() @@ -355,8 +370,10 @@ class XMLStream(object): use_tls=True, reattempt=True): """Create a new socket and connect to the server. - Setting ``reattempt`` to ``True`` will cause connection attempts to - be made every second until a successful connection is established. + Setting ``reattempt`` to ``True`` will cause connection + attempts to be made with an exponential backoff delay (max of + :attr:`reconnect_max_delay` which defaults to 10 minute) until a + successful connection is established. :param host: The name of the desired server for the connection. :param port: Port to connect to on the server. @@ -381,25 +398,31 @@ class XMLStream(object): if use_tls is not None: self.use_tls = use_tls + # Repeatedly attempt to connect until a successful connection # is established. + attempts = self.reconnect_max_attempts connected = self.state.transition('disconnected', 'connected', - func=self._connect) + func=self._connect, args=(reattempt,)) while reattempt and not connected and not self.stop.is_set(): connected = self.state.transition('disconnected', 'connected', func=self._connect) + if not connected: + if attempts is not None: + attempts -= 1 + if attempts <= 0: + self.event('connection_failed', direct=True) + return False return connected - def _connect(self): + def _connect(self, reattempt=True): self.scheduler.remove('Session timeout check') self.stop.clear() if self.default_domain: self.address = self.pick_dns_answer(self.default_domain, self.address[1]) - self.socket = self.socket_class(Socket.AF_INET, Socket.SOCK_STREAM) - self.configure_socket() - - if self.reconnect_delay is None: + + if self.reconnect_delay is None or not reattempt: delay = 1.0 else: delay = min(self.reconnect_delay * 2, self.reconnect_max_delay) @@ -417,10 +440,33 @@ class XMLStream(object): self.stop.set() return False + try: + # Look for IPv6 addresses, in addition to IPv4 + for res in Socket.getaddrinfo(self.address[0], + int(self.address[1]), + 0, + Socket.SOCK_STREAM): + log.debug("Trying: %s", res[-1]) + af, sock_type, proto, canonical, sock_addr = res + try: + self.socket = self.socket_class(af, sock_type, proto) + break + except Socket.error: + log.debug("Could not open IPv%s socket." % proto) + except Socket.gaierror: + log.warning("Socket could not be opened: no connectivity" + \ + " or wrong IP versions.") + if reattempt: + self.reconnect_delay = delay + return False + + self.configure_socket() + if self.use_proxy: connected = self._connect_proxy() if not connected: - self.reconnect_delay = delay + if reattempt: + self.reconnect_delay = delay return False if self.use_ssl and self.ssl_support: @@ -446,6 +492,12 @@ class XMLStream(object): log.debug("Connecting to %s:%s", *self.address) self.socket.connect(self.address) + if self.use_ssl and self.ssl_support: + cert = self.socket.getpeercert(binary_form=True) + cert = ssl.DER_cert_to_PEM_cert(cert) + log.debug('CERT: %s', cert) + self.event('ssl_cert', cert, direct=True) + self.set_socket(self.socket, ignore=True) #this event is where you should set your application state self.event("connected", direct=True) @@ -453,10 +505,11 @@ class XMLStream(object): return True except Socket.error as serr: error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" - self.event('socket_error', serr) + self.event('socket_error', serr, direct=True) log.error(error_msg, self.address[0], self.address[1], serr.errno, serr.strerror) - self.reconnect_delay = delay + if reattempt: + self.reconnect_delay = delay return False def _connect_proxy(self): @@ -506,7 +559,7 @@ class XMLStream(object): return True except Socket.error as serr: error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" - self.event('socket_error', serr) + self.event('socket_error', serr, direct=True) log.error(error_msg, self.address[0], self.address[1], serr.errno, serr.strerror) return False @@ -550,6 +603,7 @@ class XMLStream(object): :attr:`disconnect_wait`. """ self.state.transition('connected', 'disconnected', + wait=2.0, func=self._disconnect, args=(reconnect, wait)) def _disconnect(self, reconnect=False, wait=None): @@ -577,7 +631,7 @@ class XMLStream(object): self.socket.close() self.filesocket.close() except Socket.error as serr: - self.event('socket_error', serr) + self.event('socket_error', serr, direct=True) finally: #clear your application state self.event("disconnected", direct=True) @@ -590,6 +644,8 @@ class XMLStream(object): self.state.transition('connected', 'disconnected', wait=2.0, func=self._disconnect, args=(True,)) + attempts = self.reconnect_max_attempts + log.debug("connecting...") connected = self.state.transition('disconnected', 'connected', wait=2.0, func=self._connect) @@ -597,6 +653,12 @@ class XMLStream(object): connected = self.state.transition('disconnected', 'connected', wait=2.0, func=self._connect) connected = connected or self.state.ensure('connected') + if not connected: + if attempts is not None: + attempts -= 1 + if attempts <= 0: + self.event('connection_failed', direct=True) + return False return connected def set_socket(self, socket, ignore=False): @@ -674,6 +736,12 @@ class XMLStream(object): else: self.socket = ssl_socket self.socket.do_handshake() + + cert = self.socket.getpeercert(binary_form=True) + cert = ssl.DER_cert_to_PEM_cert(cert) + log.debug('CERT: %s', cert) + self.event('ssl_cert', cert, direct=True) + self.set_socket(self.socket) return True else: @@ -741,7 +809,29 @@ class XMLStream(object): stanza objects, but may still be processed using handlers and matchers. """ - del self.__root_stanza[stanza_class] + self.__root_stanza.remove(stanza_class) + + def add_filter(self, mode, handler, order=None): + """Add a filter for incoming or outgoing stanzas. + + These filters are applied before incoming stanzas are + passed to any handlers, and before outgoing stanzas + are put in the send queue. + + Each filter must accept a single stanza, and return + either a stanza or ``None``. If the filter returns + ``None``, then the stanza will be dropped from being + processed for events or from being sent. + + :param mode: One of ``'in'`` or ``'out'``. + :param handler: The filter function. + :param int order: The position to insert the filter in + the list of active filters. + """ + if order: + self.__filters[mode].insert(order, handler) + else: + self.__filters[mode].append(handler) def add_handler(self, mask, pointer, name=None, disposable=False, threaded=False, filter=False, instream=False): @@ -808,20 +898,44 @@ class XMLStream(object): resolver = dns.resolver.get_default_resolver() self.configure_dns(resolver, domain=domain, port=port) + v4_answers = [] + v6_answers = [] + answers = [] + try: - answers = resolver.query(domain, dns.rdatatype.A) + log.debug("Querying A records for %s" % domain) + v4_answers = resolver.query(domain, dns.rdatatype.A) except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): log.warning("No A records for %s", domain) - return [((domain, port), 0, 0)] + v4_answers = [((domain, port), 0, 0)] except dns.exception.Timeout: log.warning("DNS resolution timed out " + \ "for A record of %s", domain) - return [((domain, port), 0, 0)] + v4_answers = [((domain, port), 0, 0)] + else: + for ans in v4_answers: + log.debug("Found A record: %s", ans.address) + answers.append(((ans.address, port), 0, 0)) + + try: + log.debug("Querying AAAA records for %s" % domain) + v6_answers = resolver.query(domain, dns.rdatatype.AAAA) + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + log.warning("No AAAA records for %s", domain) + v6_answers = [((domain, port), 0, 0)] + except dns.exception.Timeout: + log.warning("DNS resolution timed out " + \ + "for AAAA record of %s", domain) + v6_answers = [((domain, port), 0, 0)] else: - return [((ans.address, port), 0, 0) for ans in answers] + for ans in v6_answers: + log.debug("Found AAAA record: %s", ans.address) + answers.append(((ans.address, port), 0, 0)) + + return answers else: log.warning("dnspython is not installed -- " + \ - "relying on OS A record resolution") + "relying on OS A/AAAA record resolution") self.configure_dns(None, domain=domain, port=port) return [((domain, port), 0, 0)] @@ -850,6 +964,7 @@ class XMLStream(object): items = [x for x in addresses.keys()] items.sort() + address = (domain, port) picked = random.randint(0, intmax) for item in items: if picked <= item: @@ -857,8 +972,8 @@ class XMLStream(object): break for idx, answer in enumerate(self.dns_answers): if self.dns_answers[0] == address: + self.dns_answers.pop(idx) break - self.dns_answers.pop(idx) log.debug("Trying to connect to %s:%s", *address) return address @@ -971,7 +1086,7 @@ class XMLStream(object): """ return xml - def send(self, data, mask=None, timeout=None, now=False): + def send(self, data, mask=None, timeout=None, now=False, use_filters=True): """A wrapper for :meth:`send_raw()` for sending stanza objects. May optionally block until an expected response is received. @@ -989,18 +1104,40 @@ class XMLStream(object): sending the stanza immediately. Useful mainly for stream initialization stanzas. Defaults to ``False``. + :param bool use_filters: Indicates if outgoing filters should be + applied to the given stanza data. Disabling + filters is useful when resending stanzas. + Defaults to ``True``. """ if timeout is None: timeout = self.response_timeout if hasattr(mask, 'xml'): mask = mask.xml - data = str(data) + + if isinstance(data, ElementBase): + if use_filters: + for filter in self.__filters['out']: + data = filter(data) + if data is None: + return + if mask is not None: log.warning("Use of send mask waiters is deprecated.") wait_for = Waiter("SendWait_%s" % self.new_id(), MatchXMLMask(mask)) self.register_handler(wait_for) - self.send_raw(data, now) + + if isinstance(data, ElementBase): + with self.send_queue_lock: + if use_filters: + for filter in self.__filters['out_sync']: + data = filter(data) + if data is None: + return + str_data = str(data) + self.send_raw(str_data, now) + else: + self.send_raw(data, now) if mask is not None: return wait_for.wait(timeout) @@ -1061,7 +1198,7 @@ class XMLStream(object): if count > 1: log.debug('SENT: %d chunks', count) except Socket.error as serr: - self.event('socket_error', serr) + self.event('socket_error', serr, direct=True) log.warning("Failed to send %s", data) if reconnect is None: reconnect = self.auto_reconnect @@ -1157,12 +1294,11 @@ class XMLStream(object): except SystemExit: log.debug("SystemExit in _process") shutdown = True - except SyntaxError as e: + except (SyntaxError, ExpatError) as e: log.error("Error reading from XML stream.") - shutdown = True self.exception(e) except Socket.error as serr: - self.event('socket_error', serr) + self.event('socket_error', serr, direct=True) log.exception('Socket Error') except Exception as e: if not self.stop.is_set(): @@ -1246,8 +1382,6 @@ class XMLStream(object): :param xml: The :class:`~sleekxmpp.xmlstream.stanzabase.ElementBase` stanza to analyze. """ - log.debug("RECV: %s", tostring(xml, xmlns=self.default_ns, - stream=self)) # Apply any preprocessing filters. xml = self.incoming_filter(xml) @@ -1255,6 +1389,14 @@ class XMLStream(object): # stanza type applies, a generic StanzaBase stanza will be used. stanza = self._build_stanza(xml) + for filter in self.__filters['in']: + if stanza is not None: + stanza = filter(stanza) + if stanza is None: + return + + log.debug("RECV: %s", stanza) + # Match the stanza against registered handlers. Handlers marked # to run "in stream" will be executed immediately; the rest will # be queued. @@ -1371,7 +1513,7 @@ class XMLStream(object): """Extract stanzas from the send queue and send them on the stream.""" try: while not self.stop.is_set(): - while not self.stop.is_set and \ + while not self.stop.is_set() and \ not self.session_started_event.is_set(): self.session_started_event.wait(timeout=1) if self.__failed_send_stanza is not None: @@ -1398,9 +1540,7 @@ class XMLStream(object): log.debug('SSL error - max retries reached') self.exception(serr) log.warning("Failed to send %s", data) - if reconnect is None: - reconnect = self.auto_reconnect - self.disconnect(reconnect) + self.disconnect(self.auto_reconnect) log.warning('SSL write error - reattempting') time.sleep(self.ssl_retry_delay) tries += 1 @@ -1408,7 +1548,7 @@ class XMLStream(object): log.debug('SENT: %d chunks', count) self.send_queue.task_done() except Socket.error as serr: - self.event('socket_error', serr) + self.event('socket_error', serr, direct=True) log.warning("Failed to send %s", data) self.__failed_send_stanza = data self.disconnect(self.auto_reconnect) |