summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r--sleekxmpp/xmlstream/resolver.py287
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py124
2 files changed, 314 insertions, 97 deletions
diff --git a/sleekxmpp/xmlstream/resolver.py b/sleekxmpp/xmlstream/resolver.py
new file mode 100644
index 00000000..ecb76519
--- /dev/null
+++ b/sleekxmpp/xmlstream/resolver.py
@@ -0,0 +1,287 @@
+# -*- encoding: utf-8 -*-
+
+"""
+ sleekxmpp.xmlstream.dns
+ ~~~~~~~~~~~~~~~~~~~~~~~
+
+ :copyright: (c) 2012 Nathanael C. Fritz
+ :license: MIT, see LICENSE for more details
+"""
+
+import socket
+import logging
+import random
+
+
+log = logging.getLogger(__name__)
+
+
+#: Global flag indicating the availability of the ``dnspython`` package.
+#: Installing ``dnspython`` can be done via:
+#:
+#: .. code-block:: sh
+#:
+#: pip install dnspython
+#:
+#: For Python3, installation may require installing from source using
+#: the ``python3`` branch:
+#:
+#: .. code-block:: sh
+#:
+#: git clone http://github.com/rthalley/dnspython
+#: cd dnspython
+#: git checkout python3
+#: python3 setup.py install
+USE_DNSPYTHON = False
+try:
+ import dns.resolver
+ USE_DNSPYTHON = True
+except ImportError as e:
+ log.debug("Could not find dnspython package. " + \
+ "Not all features will be available")
+
+
+def default_resolver():
+ """Return a basic DNS resolver object.
+
+ :returns: A :class:`dns.resolver.Resolver` object if dnspython
+ is available. Otherwise, ``None``.
+ """
+ if USE_DNSPYTHON:
+ return dns.resolver.get_default_resolver()
+ return None
+
+
+def resolve(host, port=None, service=None, proto='tcp', resolver=None):
+ """Peform DNS resolution for a given hostname.
+
+ Resolution may perform SRV record lookups if a service and protocol
+ are specified. The returned addresses will be sorted according to
+ the SRV priorities and weights.
+
+ If no resolver is provided, the dnspython resolver will be used if
+ available. Otherwise the built-in socket facilities will be used,
+ but those do not provide SRV support.
+
+ If SRV records were used, queries to resolve alternative hosts will
+ be made as needed instead of all at once.
+
+ :param host: The hostname to resolve.
+ :param port: A default port to connect with. SRV records may
+ dictate use of a different port.
+ :param service: Optional SRV service name without leading underscore.
+ :param proto: Optional SRV protocol name without leading underscore.
+ :param resolver: Optionally provide a DNS resolver object that has
+ been custom configured.
+
+ :type host: string
+ :type port: int
+ :type service: string
+ :type proto: string
+ :type resolver: :class:`dns.resolver.Resolver`
+
+ :return: An iterable of IP address, port pairs in the order
+ dictated by SRV priorities and weights, if applicable.
+ """
+ if resolver is None and USE_DNSPYTHON:
+ resolver = dns.resolver.get_default_resolver()
+
+ # An IPv6 literal is allowed to be enclosed in square brackets, but
+ # the brackets must be stripped in order to process the literal;
+ # otherwise, things break.
+ host = host.strip('[]')
+
+ try:
+ # If `host` is an IPv4 literal, we can return it immediately.
+ ipv4 = socket.inet_pton(socket.AF_INET, host)
+ yield [(host, port)]
+ except socket.error:
+ pass
+
+ try:
+ # Likewise, If `host` is an IPv6 literal, we can return it immediately.
+ ipv6 = socket.inet_pton(socket.AF_INET6, host)
+ yield [(host, port)]
+ except socket.error:
+ pass
+
+ # If no service was provided, then we can just do A/AAAA lookups on the
+ # provided host. Otherwise we need to get an ordered list of hosts to
+ # resolve based on SRV records.
+ if not service:
+ hosts = [(host, port)]
+ else:
+ hosts = get_SRV(host, port, service, proto, resolver=resolver)
+
+ for host, port in hosts:
+ results = []
+ for address in get_AAAA(host, resolver=resolver):
+ results.append((address, port))
+ for address in get_A(host, resolver=resolver):
+ results.append((address, port))
+
+ for address, port in results:
+ yield address, port
+
+
+def get_A(host, resolver=None):
+ """Lookup DNS A records for a given host.
+
+ If ``resolver`` is not provided, or is ``None``, then resolution will
+ be performed using the built-in :mod:`socket` module.
+
+ :param host: The hostname to resolve for A record IPv4 addresses.
+ :param resolver: Optional DNS resolver object to use for the query.
+
+ :type host: string
+ :type resolver: :class:`dns.resolver.Resolver` or ``None``
+
+ :return: A list of IPv4 literals.
+ """
+ log.debug("DNS: Querying %s for A records." % host)
+
+ # If not using dnspython, attempt lookup using the OS level
+ # getaddrinfo() method.
+ if resolver is None:
+ try:
+ recs = socket.getaddrinfo(host, None, socket.AF_INET,
+ socket.SOCK_STREAM)
+ return [rec[4][0] for rec in recs]
+ except socket.gaierror:
+ log.debug("DNS: Error retreiving A address info for %s." % host)
+ return []
+
+ # Using dnspython:
+ try:
+ recs = resolver.query(host, dns.rdatatype.A)
+ return [rec.to_text() for rec in recs]
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ log.debug("DNS: No A records for %s." % host)
+ return []
+ except dns.exception.Timeout:
+ log.debug("DNS: A record resolution timed out for %s." % host)
+ return []
+ except dns.exception.DNSException as e:
+ log.debug("DNS: Error querying A records for %s." % host)
+ log.exception(e)
+ return []
+
+
+def get_AAAA(host, resolver=None):
+ """Lookup DNS AAAA records for a given host.
+
+ If ``resolver`` is not provided, or is ``None``, then resolution will
+ be performed using the built-in :mod:`socket` module.
+
+ :param host: The hostname to resolve for AAAA record IPv6 addresses.
+ :param resolver: Optional DNS resolver object to use for the query.
+
+ :type host: string
+ :type resolver: :class:`dns.resolver.Resolver` or ``None``
+
+ :return: A list of IPv6 literals.
+ """
+ log.debug("DNS: Querying %s for AAAA records." % host)
+
+ # If not using dnspython, attempt lookup using the OS level
+ # getaddrinfo() method.
+ if resolver is None:
+ try:
+ recs = socket.getaddrinfo(host, None, socket.AF_INET6,
+ socket.SOCK_STREAM)
+ return [rec[4][0] for rec in recs]
+ except socket.gaierror:
+ log.debug("DNS: Error retreiving AAAA address " + \
+ "info for %s." % host)
+ return []
+
+ # Using dnspython:
+ try:
+ recs = resolver.query(host, dns.rdatatype.AAAA)
+ return [rec.to_text() for rec in recs]
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ log.debug("DNS: No AAAA records for %s." % host)
+ return []
+ except dns.exception.Timeout:
+ log.debug("DNS: AAAA record resolution timed out for %s." % host)
+ return []
+ except dns.exception.DNSException as e:
+ log.debug("DNS: Error querying AAAA records for %s." % host)
+ log.exception(e)
+ return []
+
+
+def get_SRV(host, port, service, proto='tcp', resolver=None):
+ """Perform SRV record resolution for a given host.
+
+ .. note::
+
+ This function requires the use of the ``dnspython`` package. Calling
+ :func:`get_SRV` without ``dnspython`` will return the provided host
+ and port without performing any DNS queries.
+
+ :param host: The hostname to resolve.
+ :param port: A default port to connect with. SRV records may
+ dictate use of a different port.
+ :param service: Optional SRV service name without leading underscore.
+ :param proto: Optional SRV protocol name without leading underscore.
+ :param resolver: Optionally provide a DNS resolver object that has
+ been custom configured.
+
+ :type host: string
+ :type port: int
+ :type service: string
+ :type proto: string
+ :type resolver: :class:`dns.resolver.Resolver`
+
+ :return: A list of hostname, port pairs in the order dictacted
+ by SRV priorities and weights.
+ """
+ if resolver is None:
+ return [(host, port)]
+
+ log.debug("Querying SRV records for %s" % host)
+ try:
+ recs = resolver.query('_%s._%s.%s' % (service, proto, host),
+ dns.rdatatype.SRV)
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ log.debug("DNS: No SRV records for %s." % host)
+ return [(host, port)]
+ except dns.exception.Timeout:
+ log.debug("DNS: SRV record resolution timed out for %s." % host)
+ return [(host, port)]
+ except dns.exception.DNSException as e:
+ log.debug("DNS: Error querying SRV records for %s." % host)
+ log.exception(e)
+ return [(host, port)]
+
+ if len(recs) == 1 and recs[0].target == '.':
+ return [(host, port)]
+
+ answers = {}
+ for rec in recs:
+ if rec.priority not in answers:
+ answers[rec.priority] = []
+ if rec.weight == 0:
+ answers[rec.priority].insert(0, rec)
+ else:
+ answers[rec.priority].append(rec)
+
+ sorted_recs = []
+ for priority in sorted(answers.keys()):
+ while answers[priority]:
+ running_sum = 0
+ sums = {}
+ for rec in answers[priority]:
+ running_sum += rec.weight
+ sums[running_sum] = rec
+
+ selected = random.randint(0, running_sum + 1)
+ for running_sum in sums:
+ if running_sum >= selected:
+ rec = sums[running_sum]
+ sorted_recs.append((rec.target.to_text(), rec.port))
+ answers[priority].remove(rec)
+ break
+
+ return sorted_recs
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index 45f27929..31ca9cfe 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""
sleekxmpp.xmlstream.xmlstream
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -39,19 +38,13 @@ from sleekxmpp.xmlstream import Scheduler, tostring
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
from sleekxmpp.xmlstream.matcher import MatchXMLMask
+from sleekxmpp.xmlstream.resolver import resolve, default_resolver
# In Python 2.x, file socket objects are broken. A patched socket
# wrapper is provided for this case in filesocket.py.
if sys.version_info < (3, 0):
from sleekxmpp.xmlstream.filesocket import FileSocket, Socket26
-try:
- import dns.resolver
-except ImportError:
- DNSPYTHON = False
-else:
- DNSPYTHON = True
-
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
@@ -306,6 +299,11 @@ class XMLStream(object):
#: A list of DNS results that have not yet been tried.
self.dns_answers = []
+ #: The service name to check with DNS SRV records. For
+ #: example, setting this to ``'xmpp-client'`` would query the
+ #: ``_xmpp-client._tcp`` service.
+ self.dns_service = None
+
self.add_event_handler('connected', self._handle_connected)
self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('disconnected', self._end_keepalive)
@@ -445,25 +443,10 @@ class XMLStream(object):
self.stop.set()
return False
- try:
- # Look for IPv6 addresses, in addition to IPv4
- for res in Socket.getaddrinfo(self.address[0],
- int(self.address[1]),
- 0,
- Socket.SOCK_STREAM):
- log.debug("Trying: %s", res[-1])
- af, sock_type, proto, canonical, sock_addr = res
- try:
- self.socket = self.socket_class(af, sock_type, proto)
- break
- except Socket.error:
- log.debug("Could not open IPv%s socket." % proto)
- except Socket.gaierror:
- log.warning("Socket could not be opened: no connectivity" + \
- " or wrong IP versions.")
- if reattempt:
- self.reconnect_delay = delay
- return False
+ af = Socket.AF_INET
+ if ':' in self.address[0]:
+ af = Socket.AF_INET6
+ self.socket = self.socket_class(af, Socket.SOCK_STREAM)
self.configure_socket()
@@ -511,7 +494,10 @@ class XMLStream(object):
except Socket.error as serr:
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
self.event('socket_error', serr, direct=True)
- log.error(error_msg, self.address[0], self.address[1],
+ domain = self.address[0]
+ if ':' in domain:
+ domain = '[%s]' % domain
+ log.error(error_msg, domain, self.address[1],
serr.errno, serr.strerror)
if reattempt:
self.reconnect_delay = delay
@@ -915,50 +901,11 @@ class XMLStream(object):
"""
if port is None:
port = self.default_port
- if DNSPYTHON:
- resolver = dns.resolver.get_default_resolver()
- self.configure_dns(resolver, domain=domain, port=port)
-
- v4_answers = []
- v6_answers = []
- answers = []
-
- try:
- log.debug("Querying A records for %s" % domain)
- v4_answers = resolver.query(domain, dns.rdatatype.A)
- except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
- log.warning("No A records for %s", domain)
- v4_answers = [((domain, port), 0, 0)]
- except dns.exception.Timeout:
- log.warning("DNS resolution timed out " + \
- "for A record of %s", domain)
- v4_answers = [((domain, port), 0, 0)]
- else:
- for ans in v4_answers:
- log.debug("Found A record: %s", ans.address)
- answers.append(((ans.address, port), 0, 0))
-
- try:
- log.debug("Querying AAAA records for %s" % domain)
- v6_answers = resolver.query(domain, dns.rdatatype.AAAA)
- except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
- log.warning("No AAAA records for %s", domain)
- v6_answers = [((domain, port), 0, 0)]
- except dns.exception.Timeout:
- log.warning("DNS resolution timed out " + \
- "for AAAA record of %s", domain)
- v6_answers = [((domain, port), 0, 0)]
- else:
- for ans in v6_answers:
- log.debug("Found AAAA record: %s", ans.address)
- answers.append(((ans.address, port), 0, 0))
+
+ resolver = default_resolver()
+ self.configure_dns(resolver, domain=domain, port=port)
- return answers
- else:
- log.warning("dnspython is not installed -- " + \
- "relying on OS A/AAAA record resolution")
- self.configure_dns(None, domain=domain, port=port)
- return [((domain, port), 0, 0)]
+ return resolve(domain, port, service=self.dns_service, resolver=resolver)
def pick_dns_answer(self, domain, port=None):
"""Pick a server and port from DNS answers.
@@ -971,33 +918,16 @@ class XMLStream(object):
"""
if not self.dns_answers:
self.dns_answers = self.get_dns_records(domain, port)
- addresses = {}
- intmax = 0
- topprio = 65535
- for answer in self.dns_answers:
- topprio = min(topprio, answer[1])
- for answer in self.dns_answers:
- if answer[1] == topprio:
- intmax += answer[2]
- addresses[intmax] = answer[0]
-
- #python3 returns a generator for dictionary keys
- items = [x for x in addresses.keys()]
- items.sort()
-
- address = (domain, port)
- picked = random.randint(0, intmax)
- for item in items:
- if picked <= item:
- address = addresses[item]
- break
- for idx, answer in enumerate(self.dns_answers):
- if self.dns_answers[0] == address:
- self.dns_answers.pop(idx)
- break
- log.debug("Trying to connect to %s:%s", *address)
- return address
+ try:
+ if sys.version_info < (3, 0):
+ return self.dns_answers.next()
+ else:
+ return next(self.dns_answers)
+ except StopIteration:
+ self.dns_answers = None
+ return (domain, port)
+
def add_event_handler(self, name, pointer,
threaded=False, disposable=False):
"""Add a custom event handler that will be executed whenever