summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/xmlstream/xmlstream.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 83a197de..e6af3faa 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -16,10 +16,12 @@ from typing import (
Any,
Callable,
Iterable,
+ Iterator,
List,
Optional,
Set,
Union,
+ Tuple,
)
import functools
@@ -212,7 +214,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried.
- self.dns_answers = None
+ self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
@@ -315,7 +317,7 @@ class XMLStream(asyncio.BaseProtocol):
self.event('reconnect_delay', self._connect_loop_wait)
await asyncio.sleep(self._connect_loop_wait, loop=self.loop)
- record = await self.pick_dns_answer(self.default_domain)
+ record = await self._pick_dns_answer(self.default_domain)
if record is not None:
host, address, dns_port = record
port = dns_port if dns_port else self.address[1]
@@ -324,7 +326,7 @@ class XMLStream(asyncio.BaseProtocol):
else:
# No DNS records left, stop iterating
# and try (host, port) as a last resort
- self.dns_answers = None
+ self._dns_answers = None
if self.use_ssl:
ssl_context = self.get_ssl_context()
@@ -392,7 +394,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt = None
self.init_parser()
self.send_raw(self.stream_header)
- self.dns_answers = None
+ self._dns_answers = None
def data_received(self, data):
"""Called when incoming data is received on the socket.
@@ -777,7 +779,7 @@ class XMLStream(asyncio.BaseProtocol):
idx += 1
return False
- async def get_dns_records(self, domain, port=None):
+ async def get_dns_records(self, domain: str, port: Optional[int] = None) -> List[Tuple[str, str, int]]:
"""Get the DNS records for a domain.
:param domain: The domain in question.
@@ -797,7 +799,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop)
return result
- async def pick_dns_answer(self, domain, port=None):
+ async def _pick_dns_answer(self, domain: str, port: Optional[int] = None) -> Optional[Tuple[str, str, int]]:
"""Pick a server and port from DNS answers.
Gets DNS answers if none available.
@@ -806,12 +808,12 @@ class XMLStream(asyncio.BaseProtocol):
:param domain: The domain in question.
:param port: If the results don't include a port, use this one.
"""
- if self.dns_answers is None:
+ if self._dns_answers is None:
dns_records = await self.get_dns_records(domain, port)
- self.dns_answers = iter(dns_records)
+ self._dns_answers = iter(dns_records)
try:
- return next(self.dns_answers)
+ return next(self._dns_answers)
except StopIteration:
return