diff options
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 83a197de..e6af3faa 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -16,10 +16,12 @@ from typing import ( Any, Callable, Iterable, + Iterator, List, Optional, Set, Union, + Tuple, ) import functools @@ -212,7 +214,7 @@ class XMLStream(asyncio.BaseProtocol): self._current_connection_attempt = None #: A list of DNS results that have not yet been tried. - self.dns_answers = None + self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None #: The service name to check with DNS SRV records. For #: example, setting this to ``'xmpp-client'`` would query the @@ -315,7 +317,7 @@ class XMLStream(asyncio.BaseProtocol): self.event('reconnect_delay', self._connect_loop_wait) await asyncio.sleep(self._connect_loop_wait, loop=self.loop) - record = await self.pick_dns_answer(self.default_domain) + record = await self._pick_dns_answer(self.default_domain) if record is not None: host, address, dns_port = record port = dns_port if dns_port else self.address[1] @@ -324,7 +326,7 @@ class XMLStream(asyncio.BaseProtocol): else: # No DNS records left, stop iterating # and try (host, port) as a last resort - self.dns_answers = None + self._dns_answers = None if self.use_ssl: ssl_context = self.get_ssl_context() @@ -392,7 +394,7 @@ class XMLStream(asyncio.BaseProtocol): self._current_connection_attempt = None self.init_parser() self.send_raw(self.stream_header) - self.dns_answers = None + self._dns_answers = None def data_received(self, data): """Called when incoming data is received on the socket. @@ -777,7 +779,7 @@ class XMLStream(asyncio.BaseProtocol): idx += 1 return False - async def get_dns_records(self, domain, port=None): + async def get_dns_records(self, domain: str, port: Optional[int] = None) -> List[Tuple[str, str, int]]: """Get the DNS records for a domain. :param domain: The domain in question. @@ -797,7 +799,7 @@ class XMLStream(asyncio.BaseProtocol): loop=self.loop) return result - async def pick_dns_answer(self, domain, port=None): + async def _pick_dns_answer(self, domain: str, port: Optional[int] = None) -> Optional[Tuple[str, str, int]]: """Pick a server and port from DNS answers. Gets DNS answers if none available. @@ -806,12 +808,12 @@ class XMLStream(asyncio.BaseProtocol): :param domain: The domain in question. :param port: If the results don't include a port, use this one. """ - if self.dns_answers is None: + if self._dns_answers is None: dns_records = await self.get_dns_records(domain, port) - self.dns_answers = iter(dns_records) + self._dns_answers = iter(dns_records) try: - return next(self.dns_answers) + return next(self._dns_answers) except StopIteration: return |