summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieui <mathieui@mathieui.net>2022-09-09 16:04:14 +0000
committermathieui <mathieui@mathieui.net>2022-09-09 16:04:14 +0000
commitb3a6c7a4ea5af197136dd8a6a0e6013aeb50e8f6 (patch)
tree75aa4036015ae4b831ef5b6ea2323218b1b11584
parent11e27d1d7d1d15ff404ed9bff7b28c2d1f8ffd16 (diff)
parentd43c83800e51c2455b5070a1ccaca56b57fb1575 (diff)
downloadslixmpp-b3a6c7a4ea5af197136dd8a6a0e6013aeb50e8f6.tar.gz
slixmpp-b3a6c7a4ea5af197136dd8a6a0e6013aeb50e8f6.tar.bz2
slixmpp-b3a6c7a4ea5af197136dd8a6a0e6013aeb50e8f6.tar.xz
slixmpp-b3a6c7a4ea5af197136dd8a6a0e6013aeb50e8f6.zip
Merge branch 'aiodns-gethostbyname' into 'master'
Use gethostbyname when using aiodns See merge request poezio/slixmpp!212
-rw-r--r--slixmpp/xmlstream/resolver.py36
1 files changed, 20 insertions, 16 deletions
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index e524da3b..3de6629d 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -15,7 +15,13 @@ from slixmpp.types import Protocol
log = logging.getLogger(__name__)
-class AnswerProtocol(Protocol):
+class GetHostByNameAnswerProtocol(Protocol):
+ name: str
+ aliases: List[str]
+ addresses: List[str]
+
+
+class QueryAnswerProtocol(Protocol):
host: str
priority: int
weight: int
@@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
class ResolverProtocol(Protocol):
+ def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
+ ...
+
def query(self, query: str, querytype: str) -> Future:
...
@@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
results = []
for host, port in hosts:
- if host == 'localhost':
- if use_ipv6:
- results.append((host, '::1', port))
- results.append((host, '127.0.0.1', port))
-
if use_ipv6:
aaaa = await get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop)
@@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
- future = resolver.query(host, 'A')
+ future = resolver.gethostbyname(host, socket.AF_INET)
try:
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
- recs = []
- return [rec.host for rec in recs]
+ return []
+ return [addr for addr in recs.addresses]
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
@@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
- future = resolver.query(host, 'AAAA')
+ future = resolver.gethostbyname(host, socket.AF_INET6)
try:
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
- recs = []
- return [rec.host for rec in recs]
+ return []
+ return [addr for addr in recs.addresses]
async def get_SRV(host: str, port: int, service: str,
@@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(Iterable[QueryAnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
- answers: Dict[int, List[AnswerProtocol]] = {}
+ answers: Dict[int, List[QueryAnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []