From 4931e7e6041ff6c59817d6807b0776a8994d0377 Mon Sep 17 00:00:00 2001 From: mathieui Date: Tue, 20 Apr 2021 22:14:01 +0200 Subject: refactor: type the resolver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit almost perfect, except for python < 3.9 making it so we can’t have nice things. --- slixmpp/types.py | 2 ++ slixmpp/xmlstream/resolver.py | 70 +++++++++++++++++++++++++++++-------------- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/slixmpp/types.py b/slixmpp/types.py index 453d25e3..2a25f7ad 100644 --- a/slixmpp/types.py +++ b/slixmpp/types.py @@ -16,11 +16,13 @@ try: from typing import ( Literal, TypedDict, + Protocol, ) except ImportError: from typing_extensions import ( Literal, TypedDict, + Protocol, ) from slixmpp.jid import JID diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py index 97798353..e524da3b 100644 --- a/slixmpp/xmlstream/resolver.py +++ b/slixmpp/xmlstream/resolver.py @@ -1,18 +1,32 @@ - # slixmpp.xmlstream.dns # ~~~~~~~~~~~~~~~~~~~~~~~ # :copyright: (c) 2012 Nathanael C. Fritz # :license: MIT, see LICENSE for more details -from slixmpp.xmlstream.asyncio import asyncio import socket +import sys import logging import random +from asyncio import Future, AbstractEventLoop +from typing import Optional, Tuple, Dict, List, Iterable, cast +from slixmpp.types import Protocol log = logging.getLogger(__name__) +class AnswerProtocol(Protocol): + host: str + priority: int + weight: int + port: int + + +class ResolverProtocol(Protocol): + def query(self, query: str, querytype: str) -> Future: + ... + + #: Global flag indicating the availability of the ``aiodns`` package. #: Installing ``aiodns`` can be done via: #: @@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False try: import aiodns AIODNS_AVAILABLE = True -except ImportError as e: - log.debug("Could not find aiodns package. " + \ +except ImportError: + log.debug("Could not find aiodns package. " "Not all features will be available") -def default_resolver(loop): +def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]: """Return a basic DNS resolver object. :returns: A :class:`aiodns.DNSResolver` object if aiodns @@ -41,8 +55,11 @@ def default_resolver(loop): return None -async def resolve(host, port=None, service=None, proto='tcp', - resolver=None, use_ipv6=True, use_aiodns=True, loop=None): +async def resolve(host: str, port: int, *, loop: AbstractEventLoop, + service: Optional[str] = None, proto: str = 'tcp', + resolver: Optional[ResolverProtocol] = None, + use_ipv6: bool = True, + use_aiodns: bool = True) -> List[Tuple[str, str, int]]: """Peform DNS resolution for a given hostname. Resolution may perform SRV record lookups if a service and protocol @@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp', if not use_ipv6: log.debug("DNS: Use of IPv6 has been disabled.") - if resolver is None and AIODNS_AVAILABLE and use_aiodns: - resolver = aiodns.DNSResolver(loop=loop) + if resolver is None and use_aiodns: + resolver = default_resolver(loop=loop) # An IPv6 literal is allowed to be enclosed in square brackets, but # the brackets must be stripped in order to process the literal; @@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp', try: # If `host` is an IPv4 literal, we can return it immediately. - ipv4 = socket.inet_aton(host) + socket.inet_aton(host) return [(host, host, port)] except socket.error: pass @@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp', # Likewise, If `host` is an IPv6 literal, we can return # it immediately. if hasattr(socket, 'inet_pton'): - ipv6 = socket.inet_pton(socket.AF_INET6, host) + socket.inet_pton(socket.AF_INET6, host) return [(host, host, port)] except (socket.error, ValueError): pass @@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp', return results -async def get_A(host, resolver=None, use_aiodns=True, loop=None): + +async def get_A(host: str, *, loop: AbstractEventLoop, + resolver: Optional[ResolverProtocol] = None, + use_aiodns: bool = True) -> List[str]: """Lookup DNS A records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None): # getaddrinfo() method. if resolver is None or not use_aiodns: try: - recs = await loop.getaddrinfo(host, None, + inet_recs = await loop.getaddrinfo(host, None, family=socket.AF_INET, type=socket.SOCK_STREAM) - return [rec[4][0] for rec in recs] + return [rec[4][0] for rec in inet_recs] except socket.gaierror: log.debug("DNS: Error retrieving A address info for %s." % host) return [] @@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None): # Using aiodns: future = resolver.query(host, 'A') try: - recs = await future + recs = cast(Iterable[AnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s A records: %s', host, e) recs = [] return [rec.host for rec in recs] -async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): +async def get_AAAA(host: str, *, loop: AbstractEventLoop, + resolver: Optional[ResolverProtocol] = None, + use_aiodns: bool = True) -> List[str]: """Lookup DNS AAAA records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) return [] try: - recs = await loop.getaddrinfo(host, None, + inet_recs = await loop.getaddrinfo(host, None, family=socket.AF_INET6, type=socket.SOCK_STREAM) - return [rec[4][0] for rec in recs] + return [rec[4][0] for rec in inet_recs] except (OSError, socket.gaierror): log.debug("DNS: Error retrieving AAAA address " + \ "info for %s." % host) @@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): # Using aiodns: future = resolver.query(host, 'AAAA') try: - recs = await future + recs = cast(Iterable[AnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) recs = [] return [rec.host for rec in recs] -async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True): + +async def get_SRV(host: str, port: int, service: str, + proto: str = 'tcp', + resolver: Optional[ResolverProtocol] = None, + use_aiodns: bool = True) -> List[Tuple[str, int]]: """Perform SRV record resolution for a given host. .. note:: @@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr try: future = resolver.query('_%s._%s.%s' % (service, proto, host), 'SRV') - recs = await future + recs = cast(Iterable[AnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) return [] - answers = {} + answers: Dict[int, List[AnswerProtocol]] = {} for rec in recs: if rec.priority not in answers: answers[rec.priority] = [] -- cgit v1.2.3