diff options
author | mathieui <mathieui@mathieui.net> | 2015-05-12 00:02:32 +0200 |
---|---|---|
committer | mathieui <mathieui@mathieui.net> | 2015-05-12 00:02:32 +0200 |
commit | a2852eb249d443e7aef4281bba5243db8a40c837 (patch) | |
tree | 95c355110f01a4531e3aa8b24b782e1367145796 | |
parent | f1e6d6b0a92d061683cb1d1cabceb7f90c859a73 (diff) | |
download | slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.tar.gz slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.tar.bz2 slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.tar.xz slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.zip |
Allow the use of a custom loop instead of asyncio.get_event_loop()
-rw-r--r-- | slixmpp/plugins/xep_0325/control.py | 2 | ||||
-rw-r--r-- | slixmpp/xmlstream/resolver.py | 18 | ||||
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 58 |
3 files changed, 42 insertions, 36 deletions
diff --git a/slixmpp/plugins/xep_0325/control.py b/slixmpp/plugins/xep_0325/control.py index 81ed9039..0c6837f6 100644 --- a/slixmpp/plugins/xep_0325/control.py +++ b/slixmpp/plugins/xep_0325/control.py @@ -332,7 +332,7 @@ class XEP_0325(BasePlugin): self.sessions[session]["nodeDone"][node] = False for node in self.sessions[session]["node_list"]: - timer = asyncio.get_event_loop().call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node))) + timer = self.xmpp.loop.call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node))) self.sessions[session]["commTimers"][node] = timer self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback) diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py index a9c260f0..fb2c3d31 100644 --- a/slixmpp/xmlstream/resolver.py +++ b/slixmpp/xmlstream/resolver.py @@ -32,14 +32,14 @@ except ImportError as e: "Not all features will be available") -def default_resolver(): +def default_resolver(loop): """Return a basic DNS resolver object. :returns: A :class:`aiodns.DNSResolver` object if aiodns is available. Otherwise, ``None``. """ if AIODNS_AVAILABLE: - return aiodns.DNSResolver(loop=asyncio.get_event_loop(), + return aiodns.DNSResolver(loop=loop, tries=1, timeout=1.0) return None @@ -47,7 +47,7 @@ def default_resolver(): @asyncio.coroutine def resolve(host, port=None, service=None, proto='tcp', - resolver=None, use_ipv6=True, use_aiodns=True): + resolver=None, use_ipv6=True, use_aiodns=True, loop=None): """Peform DNS resolution for a given hostname. Resolution may perform SRV record lookups if a service and protocol @@ -97,7 +97,7 @@ def resolve(host, port=None, service=None, proto='tcp', log.debug("DNS: Use of IPv6 has been disabled.") if resolver is None and AIODNS_AVAILABLE and use_aiodns: - resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop()) + resolver = aiodns.DNSResolver(loop=loop) # An IPv6 literal is allowed to be enclosed in square brackets, but # the brackets must be stripped in order to process the literal; @@ -142,19 +142,19 @@ def resolve(host, port=None, service=None, proto='tcp', if use_ipv6: aaaa = yield from get_AAAA(host, resolver=resolver, - use_aiodns=use_aiodns) + use_aiodns=use_aiodns, loop=loop) for address in aaaa: results.append((host, address, port)) a = yield from get_A(host, resolver=resolver, - use_aiodns=use_aiodns) + use_aiodns=use_aiodns, loop=loop) for address in a: results.append((host, address, port)) return results @asyncio.coroutine -def get_A(host, resolver=None, use_aiodns=True): +def get_A(host, resolver=None, use_aiodns=True, loop=None): """Lookup DNS A records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -177,7 +177,6 @@ def get_A(host, resolver=None, use_aiodns=True): # If not using aiodns, attempt lookup using the OS level # getaddrinfo() method. if resolver is None or not use_aiodns: - loop = asyncio.get_event_loop() try: recs = yield from loop.getaddrinfo(host, None, family=socket.AF_INET, @@ -198,7 +197,7 @@ def get_A(host, resolver=None, use_aiodns=True): @asyncio.coroutine -def get_AAAA(host, resolver=None, use_aiodns=True): +def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): """Lookup DNS AAAA records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -224,7 +223,6 @@ def get_AAAA(host, resolver=None, use_aiodns=True): if not socket.has_ipv6: log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) return [] - loop = asyncio.get_event_loop() try: recs = yield from loop.getaddrinfo(host, None, family=socket.AF_INET6, diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 71873e48..866368bd 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol): self._der_cert = None + # The asyncio event loop + self._loop = None + #: The default port to return when querying DNS records. self.default_port = int(port) @@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol): self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) + @property + def loop(self): + if self._loop is None: + self._loop = asyncio.get_event_loop() + return self._loop + + @loop.setter + def loop(self, value): + self._loop = value + def new_id(self): """Generate and return a new stream ID in hexadecimal form. @@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol): @asyncio.coroutine def _connect_routine(self): - loop = asyncio.get_event_loop() self.event_when_connected = "connected" try: @@ -290,10 +302,10 @@ class XMLStream(asyncio.BaseProtocol): self.dns_answers = None try: - yield from loop.create_connection(lambda: self, - self.address[0], - self.address[1], - ssl=self.use_ssl) + yield from self.loop.create_connection(lambda: self, + self.address[0], + self.address[1], + ssl=self.use_ssl) except Socket.gaierror as e: self.event('connection_failed', 'No DNS record available for %s' % self.default_domain) @@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol): function will run forever. If timeout is a number, this function will return after the given time in seconds. """ - loop = asyncio.get_event_loop() if timeout is None: if forever: - loop.run_forever() + self.loop.run_forever() else: - loop.run_until_complete(self.disconnected) + self.loop.run_until_complete(self.disconnected) else: tasks = [asyncio.sleep(timeout)] if not forever: tasks.append(self.disconnected) - loop.run_until_complete(asyncio.wait(tasks)) + self.loop.run_until_complete(asyncio.wait(tasks)) def init_parser(self): """init the XML parser. The parser must always be reset for each new @@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol): elif self.xml_depth == 1: # A stanza is an XML element that is a direct child of # the root element, hence the check of depth == 1 - asyncio.get_event_loop().\ - idle_call(functools.partial(self.__spawn_event, xml)) + self.loop.idle_call(functools.partial(self.__spawn_event, xml)) if self.xml_root is not None: # Keep the root element empty of children to # save on memory use. @@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol): If the handshake is successful, the XML stream will need to be restarted. """ - loop = asyncio.get_event_loop() self.event_when_connected = "tls_success" if self.ciphers is not None: @@ -478,9 +487,9 @@ class XMLStream(asyncio.BaseProtocol): self.ssl_context.verify_mode = ssl.CERT_REQUIRED self.ssl_context.load_verify_locations(cafile=self.ca_certs) - ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, - sock=self.socket, - server_hostname=self.address[0]) + ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context, + sock=self.socket, + server_hostname=self.address[0]) @asyncio.coroutine def ssl_coro(): try: @@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol): if port is None: port = self.default_port - resolver = default_resolver() + resolver = default_resolver(loop=self.loop) self.configure_dns(resolver, domain=domain, port=port) result = yield from resolve(domain, port, service=self.dns_service, resolver=resolver, use_ipv6=self.use_ipv6, - use_aiodns=self.use_aiodns) + use_aiodns=self.use_aiodns, + loop=self.loop) return result @asyncio.coroutine @@ -746,14 +756,13 @@ class XMLStream(asyncio.BaseProtocol): """ if seconds is None: seconds = RESPONSE_TIMEOUT - loop = asyncio.get_event_loop() cb = functools.partial(callback, *args, **kwargs) if repeat: - handle = loop.call_later(seconds, self._execute_and_reschedule, - name, cb, seconds) + handle = self.loop.call_later(seconds, self._execute_and_reschedule, + name, cb, seconds) else: - handle = loop.call_later(seconds, self._execute_and_unschedule, - name, cb) + handle = self.loop.call_later(seconds, self._execute_and_unschedule, + name, cb) # Save that handle, so we can just cancel this scheduled event by # canceling scheduled_events[name] @@ -778,9 +787,8 @@ class XMLStream(asyncio.BaseProtocol): be called after the given number of seconds. """ self._safe_cb_run(name, cb) - loop = asyncio.get_event_loop() - handle = loop.call_later(seconds, self._execute_and_reschedule, - name, cb, seconds) + handle = self.loop.call_later(seconds, self._execute_and_reschedule, + name, cb, seconds) self.scheduled_events[name] = handle def _execute_and_unschedule(self, name, cb): |