summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/types.py2
-rw-r--r--slixmpp/xmlstream/resolver.py70
2 files changed, 50 insertions, 22 deletions
diff --git a/slixmpp/types.py b/slixmpp/types.py
index 453d25e3..2a25f7ad 100644
--- a/slixmpp/types.py
+++ b/slixmpp/types.py
@@ -16,11 +16,13 @@ try:
from typing import (
Literal,
TypedDict,
+ Protocol,
)
except ImportError:
from typing_extensions import (
Literal,
TypedDict,
+ Protocol,
)
from slixmpp.jid import JID
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index 97798353..e524da3b 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -1,18 +1,32 @@
-
# slixmpp.xmlstream.dns
# ~~~~~~~~~~~~~~~~~~~~~~~
# :copyright: (c) 2012 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
-from slixmpp.xmlstream.asyncio import asyncio
import socket
+import sys
import logging
import random
+from asyncio import Future, AbstractEventLoop
+from typing import Optional, Tuple, Dict, List, Iterable, cast
+from slixmpp.types import Protocol
log = logging.getLogger(__name__)
+class AnswerProtocol(Protocol):
+ host: str
+ priority: int
+ weight: int
+ port: int
+
+
+class ResolverProtocol(Protocol):
+ def query(self, query: str, querytype: str) -> Future:
+ ...
+
+
#: Global flag indicating the availability of the ``aiodns`` package.
#: Installing ``aiodns`` can be done via:
#:
@@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
try:
import aiodns
AIODNS_AVAILABLE = True
-except ImportError as e:
- log.debug("Could not find aiodns package. " + \
+except ImportError:
+ log.debug("Could not find aiodns package. "
"Not all features will be available")
-def default_resolver(loop):
+def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
"""Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns
@@ -41,8 +55,11 @@ def default_resolver(loop):
return None
-async def resolve(host, port=None, service=None, proto='tcp',
- resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
+async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
+ service: Optional[str] = None, proto: str = 'tcp',
+ resolver: Optional[ResolverProtocol] = None,
+ use_ipv6: bool = True,
+ use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
if not use_ipv6:
log.debug("DNS: Use of IPv6 has been disabled.")
- if resolver is None and AIODNS_AVAILABLE and use_aiodns:
- resolver = aiodns.DNSResolver(loop=loop)
+ if resolver is None and use_aiodns:
+ resolver = default_resolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal;
@@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
try:
# If `host` is an IPv4 literal, we can return it immediately.
- ipv4 = socket.inet_aton(host)
+ socket.inet_aton(host)
return [(host, host, port)]
except socket.error:
pass
@@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
# Likewise, If `host` is an IPv6 literal, we can return
# it immediately.
if hasattr(socket, 'inet_pton'):
- ipv6 = socket.inet_pton(socket.AF_INET6, host)
+ socket.inet_pton(socket.AF_INET6, host)
return [(host, host, port)]
except (socket.error, ValueError):
pass
@@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
return results
-async def get_A(host, resolver=None, use_aiodns=True, loop=None):
+
+async def get_A(host: str, *, loop: AbstractEventLoop,
+ resolver: Optional[ResolverProtocol] = None,
+ use_aiodns: bool = True) -> List[str]:
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# getaddrinfo() method.
if resolver is None or not use_aiodns:
try:
- recs = await loop.getaddrinfo(host, None,
+ inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET,
type=socket.SOCK_STREAM)
- return [rec[4][0] for rec in recs]
+ return [rec[4][0] for rec in inet_recs]
except socket.gaierror:
log.debug("DNS: Error retrieving A address info for %s." % host)
return []
@@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'A')
try:
- recs = await future
+ recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = []
return [rec.host for rec in recs]
-async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
+async def get_AAAA(host: str, *, loop: AbstractEventLoop,
+ resolver: Optional[ResolverProtocol] = None,
+ use_aiodns: bool = True) -> List[str]:
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return []
try:
- recs = await loop.getaddrinfo(host, None,
+ inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET6,
type=socket.SOCK_STREAM)
- return [rec[4][0] for rec in recs]
+ return [rec[4][0] for rec in inet_recs]
except (OSError, socket.gaierror):
log.debug("DNS: Error retrieving AAAA address " + \
"info for %s." % host)
@@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'AAAA')
try:
- recs = await future
+ recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = []
return [rec.host for rec in recs]
-async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
+
+async def get_SRV(host: str, port: int, service: str,
+ proto: str = 'tcp',
+ resolver: Optional[ResolverProtocol] = None,
+ use_aiodns: bool = True) -> List[Tuple[str, int]]:
"""Perform SRV record resolution for a given host.
.. note::
@@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
- recs = await future
+ recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
- answers = {}
+ answers: Dict[int, List[AnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []