summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/plugins/xep_0444/stanza.py4
-rw-r--r--slixmpp/xmlstream/resolver.py36
-rw-r--r--slixmpp/xmlstream/xmlstream.py4
3 files changed, 23 insertions, 21 deletions
diff --git a/slixmpp/plugins/xep_0444/stanza.py b/slixmpp/plugins/xep_0444/stanza.py
index 02684df1..c9ee07d7 100644
--- a/slixmpp/plugins/xep_0444/stanza.py
+++ b/slixmpp/plugins/xep_0444/stanza.py
@@ -6,9 +6,7 @@
from typing import Set, Iterable
from slixmpp.xmlstream import ElementBase
try:
- from emoji import UNICODE_EMOJI
- if UNICODE_EMOJI.get('en'):
- UNICODE_EMOJI = UNICODE_EMOJI['en']
+ from emoji import EMOJI_DATA as UNICODE_EMOJI
except ImportError:
UNICODE_EMOJI = None
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index e524da3b..3de6629d 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -15,7 +15,13 @@ from slixmpp.types import Protocol
log = logging.getLogger(__name__)
-class AnswerProtocol(Protocol):
+class GetHostByNameAnswerProtocol(Protocol):
+ name: str
+ aliases: List[str]
+ addresses: List[str]
+
+
+class QueryAnswerProtocol(Protocol):
host: str
priority: int
weight: int
@@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
class ResolverProtocol(Protocol):
+ def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
+ ...
+
def query(self, query: str, querytype: str) -> Future:
...
@@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
results = []
for host, port in hosts:
- if host == 'localhost':
- if use_ipv6:
- results.append((host, '::1', port))
- results.append((host, '127.0.0.1', port))
-
if use_ipv6:
aaaa = await get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop)
@@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
- future = resolver.query(host, 'A')
+ future = resolver.gethostbyname(host, socket.AF_INET)
try:
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
- recs = []
- return [rec.host for rec in recs]
+ return []
+ return [addr for addr in recs.addresses]
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
@@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
- future = resolver.query(host, 'AAAA')
+ future = resolver.gethostbyname(host, socket.AF_INET6)
try:
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
- recs = []
- return [rec.host for rec in recs]
+ return []
+ return [addr for addr in recs.addresses]
async def get_SRV(host: str, port: int, service: str,
@@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(Iterable[QueryAnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
- answers: Dict[int, List[AnswerProtocol]] = {}
+ answers: Dict[int, List[QueryAnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 18464ccd..19c4ddcc 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -574,7 +574,7 @@ class XMLStream(asyncio.BaseProtocol):
stream=self,
top_level=True,
open_only=True))
- self.start_stream_handler(self.xml_root)
+ self.start_stream_handler(self.xml_root) # type:ignore
self.xml_depth += 1
if event == 'end':
self.xml_depth -= 1
@@ -1267,7 +1267,7 @@ class XMLStream(asyncio.BaseProtocol):
already_run_filters.add(filter)
if iscoroutinefunction(filter):
filter = cast(AsyncFilter, filter)
- task = asyncio.create_task(filter(data))
+ task = asyncio.create_task(filter(data)) # type:ignore
completed, pending = await wait(
{task},
timeout=1,