From a2852eb249d443e7aef4281bba5243db8a40c837 Mon Sep 17 00:00:00 2001 From: mathieui Date: Tue, 12 May 2015 00:02:32 +0200 Subject: Allow the use of a custom loop instead of asyncio.get_event_loop() --- slixmpp/xmlstream/xmlstream.py | 58 ++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 25 deletions(-) (limited to 'slixmpp/xmlstream/xmlstream.py') diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 71873e48..866368bd 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol): self._der_cert = None + # The asyncio event loop + self._loop = None + #: The default port to return when querying DNS records. self.default_port = int(port) @@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol): self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) + @property + def loop(self): + if self._loop is None: + self._loop = asyncio.get_event_loop() + return self._loop + + @loop.setter + def loop(self, value): + self._loop = value + def new_id(self): """Generate and return a new stream ID in hexadecimal form. @@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol): @asyncio.coroutine def _connect_routine(self): - loop = asyncio.get_event_loop() self.event_when_connected = "connected" try: @@ -290,10 +302,10 @@ class XMLStream(asyncio.BaseProtocol): self.dns_answers = None try: - yield from loop.create_connection(lambda: self, - self.address[0], - self.address[1], - ssl=self.use_ssl) + yield from self.loop.create_connection(lambda: self, + self.address[0], + self.address[1], + ssl=self.use_ssl) except Socket.gaierror as e: self.event('connection_failed', 'No DNS record available for %s' % self.default_domain) @@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol): function will run forever. If timeout is a number, this function will return after the given time in seconds. """ - loop = asyncio.get_event_loop() if timeout is None: if forever: - loop.run_forever() + self.loop.run_forever() else: - loop.run_until_complete(self.disconnected) + self.loop.run_until_complete(self.disconnected) else: tasks = [asyncio.sleep(timeout)] if not forever: tasks.append(self.disconnected) - loop.run_until_complete(asyncio.wait(tasks)) + self.loop.run_until_complete(asyncio.wait(tasks)) def init_parser(self): """init the XML parser. The parser must always be reset for each new @@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol): elif self.xml_depth == 1: # A stanza is an XML element that is a direct child of # the root element, hence the check of depth == 1 - asyncio.get_event_loop().\ - idle_call(functools.partial(self.__spawn_event, xml)) + self.loop.idle_call(functools.partial(self.__spawn_event, xml)) if self.xml_root is not None: # Keep the root element empty of children to # save on memory use. @@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol): If the handshake is successful, the XML stream will need to be restarted. """ - loop = asyncio.get_event_loop() self.event_when_connected = "tls_success" if self.ciphers is not None: @@ -478,9 +487,9 @@ class XMLStream(asyncio.BaseProtocol): self.ssl_context.verify_mode = ssl.CERT_REQUIRED self.ssl_context.load_verify_locations(cafile=self.ca_certs) - ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, - sock=self.socket, - server_hostname=self.address[0]) + ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context, + sock=self.socket, + server_hostname=self.address[0]) @asyncio.coroutine def ssl_coro(): try: @@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol): if port is None: port = self.default_port - resolver = default_resolver() + resolver = default_resolver(loop=self.loop) self.configure_dns(resolver, domain=domain, port=port) result = yield from resolve(domain, port, service=self.dns_service, resolver=resolver, use_ipv6=self.use_ipv6, - use_aiodns=self.use_aiodns) + use_aiodns=self.use_aiodns, + loop=self.loop) return result @asyncio.coroutine @@ -746,14 +756,13 @@ class XMLStream(asyncio.BaseProtocol): """ if seconds is None: seconds = RESPONSE_TIMEOUT - loop = asyncio.get_event_loop() cb = functools.partial(callback, *args, **kwargs) if repeat: - handle = loop.call_later(seconds, self._execute_and_reschedule, - name, cb, seconds) + handle = self.loop.call_later(seconds, self._execute_and_reschedule, + name, cb, seconds) else: - handle = loop.call_later(seconds, self._execute_and_unschedule, - name, cb) + handle = self.loop.call_later(seconds, self._execute_and_unschedule, + name, cb) # Save that handle, so we can just cancel this scheduled event by # canceling scheduled_events[name] @@ -778,9 +787,8 @@ class XMLStream(asyncio.BaseProtocol): be called after the given number of seconds. """ self._safe_cb_run(name, cb) - loop = asyncio.get_event_loop() - handle = loop.call_later(seconds, self._execute_and_reschedule, - name, cb, seconds) + handle = self.loop.call_later(seconds, self._execute_and_reschedule, + name, cb, seconds) self.scheduled_events[name] = handle def _execute_and_unschedule(self, name, cb): -- cgit v1.2.3