From 4931e7e6041ff6c59817d6807b0776a8994d0377 Mon Sep 17 00:00:00 2001
From: mathieui <mathieui@mathieui.net>
Date: Tue, 20 Apr 2021 22:14:01 +0200
Subject: refactor: type the resolver
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

almost perfect, except for python < 3.9 making it so we can’t have nice
things.
---
 slixmpp/types.py              |  2 ++
 slixmpp/xmlstream/resolver.py | 70 +++++++++++++++++++++++++++++--------------
 2 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/slixmpp/types.py b/slixmpp/types.py
index 453d25e3..2a25f7ad 100644
--- a/slixmpp/types.py
+++ b/slixmpp/types.py
@@ -16,11 +16,13 @@ try:
     from typing import (
         Literal,
         TypedDict,
+        Protocol,
     )
 except ImportError:
     from typing_extensions import (
         Literal,
         TypedDict,
+        Protocol,
     )
 
 from slixmpp.jid import JID
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index 97798353..e524da3b 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -1,18 +1,32 @@
-
 # slixmpp.xmlstream.dns
 # ~~~~~~~~~~~~~~~~~~~~~~~
 # :copyright: (c) 2012 Nathanael C. Fritz
 # :license: MIT, see LICENSE for more details
 
-from slixmpp.xmlstream.asyncio import asyncio
 import socket
+import sys
 import logging
 import random
+from asyncio import Future, AbstractEventLoop
+from typing import Optional, Tuple, Dict, List, Iterable, cast
+from slixmpp.types import Protocol
 
 
 log = logging.getLogger(__name__)
 
 
+class AnswerProtocol(Protocol):
+    host: str
+    priority: int
+    weight: int
+    port: int
+
+
+class ResolverProtocol(Protocol):
+    def query(self, query: str, querytype: str) -> Future:
+        ...
+
+
 #: Global flag indicating the availability of the ``aiodns`` package.
 #: Installing ``aiodns`` can be done via:
 #:
@@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
 try:
     import aiodns
     AIODNS_AVAILABLE = True
-except ImportError as e:
-    log.debug("Could not find aiodns package. " + \
+except ImportError:
+    log.debug("Could not find aiodns package. "
               "Not all features will be available")
 
 
-def default_resolver(loop):
+def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
     """Return a basic DNS resolver object.
 
     :returns: A :class:`aiodns.DNSResolver` object if aiodns
@@ -41,8 +55,11 @@ def default_resolver(loop):
     return None
 
 
-async def resolve(host, port=None, service=None, proto='tcp',
-            resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
+async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
+                  service: Optional[str] = None, proto: str = 'tcp',
+                  resolver: Optional[ResolverProtocol] = None,
+                  use_ipv6: bool = True,
+                  use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
     """Peform DNS resolution for a given hostname.
 
     Resolution may perform SRV record lookups if a service and protocol
@@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
     if not use_ipv6:
         log.debug("DNS: Use of IPv6 has been disabled.")
 
-    if resolver is None and AIODNS_AVAILABLE and use_aiodns:
-        resolver = aiodns.DNSResolver(loop=loop)
+    if resolver is None and use_aiodns:
+        resolver = default_resolver(loop=loop)
 
     # An IPv6 literal is allowed to be enclosed in square brackets, but
     # the brackets must be stripped in order to process the literal;
@@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
 
     try:
         # If `host` is an IPv4 literal, we can return it immediately.
-        ipv4 = socket.inet_aton(host)
+        socket.inet_aton(host)
         return [(host, host, port)]
     except socket.error:
         pass
@@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
             # Likewise, If `host` is an IPv6 literal, we can return
             # it immediately.
             if hasattr(socket, 'inet_pton'):
-                ipv6 = socket.inet_pton(socket.AF_INET6, host)
+                socket.inet_pton(socket.AF_INET6, host)
                 return [(host, host, port)]
         except (socket.error, ValueError):
             pass
@@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
 
     return results
 
-async def get_A(host, resolver=None, use_aiodns=True, loop=None):
+
+async def get_A(host: str, *, loop: AbstractEventLoop,
+                resolver: Optional[ResolverProtocol] = None,
+                use_aiodns: bool = True) -> List[str]:
     """Lookup DNS A records for a given host.
 
     If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
     # getaddrinfo() method.
     if resolver is None or not use_aiodns:
         try:
-            recs = await loop.getaddrinfo(host, None,
+            inet_recs = await loop.getaddrinfo(host, None,
                                                family=socket.AF_INET,
                                                type=socket.SOCK_STREAM)
-            return [rec[4][0] for rec in recs]
+            return [rec[4][0] for rec in inet_recs]
         except socket.gaierror:
             log.debug("DNS: Error retrieving A address info for %s." % host)
             return []
@@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
     # Using aiodns:
     future = resolver.query(host, 'A')
     try:
-        recs = await future
+        recs = cast(Iterable[AnswerProtocol], await future)
     except Exception as e:
         log.debug('DNS: Exception while querying for %s A records: %s', host, e)
         recs = []
     return [rec.host for rec in recs]
 
 
-async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
+async def get_AAAA(host: str, *, loop: AbstractEventLoop,
+                   resolver: Optional[ResolverProtocol] = None,
+                   use_aiodns: bool = True) -> List[str]:
     """Lookup DNS AAAA records for a given host.
 
     If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
             log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
             return []
         try:
-            recs = await loop.getaddrinfo(host, None,
+            inet_recs = await loop.getaddrinfo(host, None,
                                                family=socket.AF_INET6,
                                                type=socket.SOCK_STREAM)
-            return [rec[4][0] for rec in recs]
+            return [rec[4][0] for rec in inet_recs]
         except (OSError, socket.gaierror):
             log.debug("DNS: Error retrieving AAAA address " + \
                       "info for %s." % host)
@@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
     # Using aiodns:
     future = resolver.query(host, 'AAAA')
     try:
-        recs = await future
+        recs = cast(Iterable[AnswerProtocol], await future)
     except Exception as e:
         log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
         recs = []
     return [rec.host for rec in recs]
 
-async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
+
+async def get_SRV(host: str, port: int, service: str,
+                  proto: str = 'tcp',
+                  resolver: Optional[ResolverProtocol] = None,
+                  use_aiodns: bool = True) -> List[Tuple[str, int]]:
     """Perform SRV record resolution for a given host.
 
     .. note::
@@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
     try:
         future = resolver.query('_%s._%s.%s' % (service, proto, host),
                                 'SRV')
-        recs = await future
+        recs = cast(Iterable[AnswerProtocol], await future)
     except Exception as e:
         log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
         return []
 
-    answers = {}
+    answers: Dict[int, List[AnswerProtocol]] = {}
     for rec in recs:
         if rec.priority not in answers:
             answers[rec.priority] = []
-- 
cgit v1.2.3