From 3502480384bd7d9f4e4eb1a3b92e8df08f4e487c Mon Sep 17 00:00:00 2001 From: Emmanuel Gil Peyrot Date: Sun, 1 Jul 2018 18:46:33 +0200 Subject: Switch from @asyncio.coroutine to async def everywhere. --- slixmpp/xmlstream/handler/coroutine_callback.py | 5 ++--- slixmpp/xmlstream/handler/waiter.py | 5 ++--- slixmpp/xmlstream/resolver.py | 26 ++++++++++------------ slixmpp/xmlstream/xmlstream.py | 29 ++++++++++--------------- 4 files changed, 27 insertions(+), 38 deletions(-) (limited to 'slixmpp/xmlstream') diff --git a/slixmpp/xmlstream/handler/coroutine_callback.py b/slixmpp/xmlstream/handler/coroutine_callback.py index 1ca4ab0a..0708a6e4 100644 --- a/slixmpp/xmlstream/handler/coroutine_callback.py +++ b/slixmpp/xmlstream/handler/coroutine_callback.py @@ -45,10 +45,9 @@ class CoroutineCallback(BaseHandler): if not asyncio.iscoroutinefunction(pointer): raise ValueError("Given function is not a coroutine") - @asyncio.coroutine - def pointer_wrapper(stanza, *args, **kwargs): + async def pointer_wrapper(stanza, *args, **kwargs): try: - yield from pointer(stanza, *args, **kwargs) + await pointer(stanza, *args, **kwargs) except Exception as e: stanza.exception(e) diff --git a/slixmpp/xmlstream/handler/waiter.py b/slixmpp/xmlstream/handler/waiter.py index 8a4d74ea..b82fa5ca 100644 --- a/slixmpp/xmlstream/handler/waiter.py +++ b/slixmpp/xmlstream/handler/waiter.py @@ -50,8 +50,7 @@ class Waiter(BaseHandler): """Do not process this handler during the main event loop.""" pass - @asyncio.coroutine - def wait(self, timeout=None): + async def wait(self, timeout=None): """Block an event handler while waiting for a stanza to arrive. Be aware that this will impact performance if called from a @@ -70,7 +69,7 @@ class Waiter(BaseHandler): stanza = None try: - stanza = yield from self._payload.get() + stanza = await self._payload.get() except TimeoutError: log.warning("Timed out waiting for %s", self.name) self.stream().remove_handler(self.name) diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py index 23f7f039..3c3c9dda 100644 --- a/slixmpp/xmlstream/resolver.py +++ b/slixmpp/xmlstream/resolver.py @@ -45,8 +45,7 @@ def default_resolver(loop): return None -@asyncio.coroutine -def resolve(host, port=None, service=None, proto='tcp', +async def resolve(host, port=None, service=None, proto='tcp', resolver=None, use_ipv6=True, use_aiodns=True, loop=None): """Peform DNS resolution for a given hostname. @@ -127,7 +126,7 @@ def resolve(host, port=None, service=None, proto='tcp', if not service: hosts = [(host, port)] else: - hosts = yield from get_SRV(host, port, service, proto, + hosts = await get_SRV(host, port, service, proto, resolver=resolver, use_aiodns=use_aiodns) if not hosts: @@ -141,19 +140,18 @@ def resolve(host, port=None, service=None, proto='tcp', results.append((host, '127.0.0.1', port)) if use_ipv6: - aaaa = yield from get_AAAA(host, resolver=resolver, + aaaa = await get_AAAA(host, resolver=resolver, use_aiodns=use_aiodns, loop=loop) for address in aaaa: results.append((host, address, port)) - a = yield from get_A(host, resolver=resolver, + a = await get_A(host, resolver=resolver, use_aiodns=use_aiodns, loop=loop) for address in a: results.append((host, address, port)) return results -@asyncio.coroutine def get_A(host, resolver=None, use_aiodns=True, loop=None): """Lookup DNS A records for a given host. @@ -178,7 +176,7 @@ def get_A(host, resolver=None, use_aiodns=True, loop=None): # getaddrinfo() method. if resolver is None or not use_aiodns: try: - recs = yield from loop.getaddrinfo(host, None, + recs = await loop.getaddrinfo(host, None, family=socket.AF_INET, type=socket.SOCK_STREAM) return [rec[4][0] for rec in recs] @@ -189,15 +187,14 @@ def get_A(host, resolver=None, use_aiodns=True, loop=None): # Using aiodns: future = resolver.query(host, 'A') try: - recs = yield from future + recs = await future except Exception as e: log.debug('DNS: Exception while querying for %s A records: %s', host, e) recs = [] return [rec.host for rec in recs] -@asyncio.coroutine -def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): +async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): """Lookup DNS AAAA records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -224,7 +221,7 @@ def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) return [] try: - recs = yield from loop.getaddrinfo(host, None, + recs = await loop.getaddrinfo(host, None, family=socket.AF_INET6, type=socket.SOCK_STREAM) return [rec[4][0] for rec in recs] @@ -236,14 +233,13 @@ def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): # Using aiodns: future = resolver.query(host, 'AAAA') try: - recs = yield from future + recs = await future except Exception as e: log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) recs = [] return [rec.host for rec in recs] -@asyncio.coroutine -def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True): +async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True): """Perform SRV record resolution for a given host. .. note:: @@ -277,7 +273,7 @@ def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True): try: future = resolver.query('_%s._%s.%s' % (service, proto, host), 'SRV') - recs = yield from future + recs = await future except Exception as e: log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) return [] diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 93652815..d5dce586 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -287,11 +287,10 @@ class XMLStream(asyncio.BaseProtocol): self.event("connecting") self._current_connection_attempt = asyncio.ensure_future(self._connect_routine()) - @asyncio.coroutine - def _connect_routine(self): + async def _connect_routine(self): self.event_when_connected = "connected" - record = yield from self.pick_dns_answer(self.default_domain) + record = await self.pick_dns_answer(self.default_domain) if record is not None: host, address, dns_port = record port = dns_port if dns_port else self.address[1] @@ -307,9 +306,9 @@ class XMLStream(asyncio.BaseProtocol): else: ssl_context = None - yield from asyncio.sleep(self.connect_loop_wait) + await asyncio.sleep(self.connect_loop_wait) try: - yield from self.loop.create_connection(lambda: self, + await self.loop.create_connection(lambda: self, self.address[0], self.address[1], ssl=ssl_context, @@ -540,10 +539,9 @@ class XMLStream(asyncio.BaseProtocol): ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context, sock=self.socket, server_hostname=self.default_domain) - @asyncio.coroutine - def ssl_coro(): + async def ssl_coro(): try: - transp, prot = yield from ssl_connect_routine + transp, prot = await ssl_connect_routine except ssl.SSLError as e: log.debug('SSL: Unable to connect', exc_info=True) log.error('CERT: Invalid certificate trust chain.') @@ -671,8 +669,7 @@ class XMLStream(asyncio.BaseProtocol): idx += 1 return False - @asyncio.coroutine - def get_dns_records(self, domain, port=None): + async def get_dns_records(self, domain, port=None): """Get the DNS records for a domain. :param domain: The domain in question. @@ -684,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol): resolver = default_resolver(loop=self.loop) self.configure_dns(resolver, domain=domain, port=port) - result = yield from resolve(domain, port, + result = await resolve(domain, port, service=self.dns_service, resolver=resolver, use_ipv6=self.use_ipv6, @@ -692,8 +689,7 @@ class XMLStream(asyncio.BaseProtocol): loop=self.loop) return result - @asyncio.coroutine - def pick_dns_answer(self, domain, port=None): + async def pick_dns_answer(self, domain, port=None): """Pick a server and port from DNS answers. Gets DNS answers if none available. @@ -703,7 +699,7 @@ class XMLStream(asyncio.BaseProtocol): :param port: If the results don't include a port, use this one. """ if self.dns_answers is None: - dns_records = yield from self.get_dns_records(domain, port) + dns_records = await self.get_dns_records(domain, port) self.dns_answers = iter(dns_records) try: @@ -768,10 +764,9 @@ class XMLStream(asyncio.BaseProtocol): # If the callback is a coroutine, schedule it instead of # running it directly if asyncio.iscoroutinefunction(handler_callback): - @asyncio.coroutine - def handler_callback_routine(cb): + async def handler_callback_routine(cb): try: - yield from cb(data) + await cb(data) except Exception as e: if old_exception: old_exception(e) -- cgit v1.2.3