summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieui <mathieui@mathieui.net>2015-05-12 00:02:32 +0200
committermathieui <mathieui@mathieui.net>2015-05-12 00:02:32 +0200
commita2852eb249d443e7aef4281bba5243db8a40c837 (patch)
tree95c355110f01a4531e3aa8b24b782e1367145796
parentf1e6d6b0a92d061683cb1d1cabceb7f90c859a73 (diff)
downloadslixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.tar.gz
slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.tar.bz2
slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.tar.xz
slixmpp-a2852eb249d443e7aef4281bba5243db8a40c837.zip
Allow the use of a custom loop instead of asyncio.get_event_loop()
-rw-r--r--slixmpp/plugins/xep_0325/control.py2
-rw-r--r--slixmpp/xmlstream/resolver.py18
-rw-r--r--slixmpp/xmlstream/xmlstream.py58
3 files changed, 42 insertions, 36 deletions
diff --git a/slixmpp/plugins/xep_0325/control.py b/slixmpp/plugins/xep_0325/control.py
index 81ed9039..0c6837f6 100644
--- a/slixmpp/plugins/xep_0325/control.py
+++ b/slixmpp/plugins/xep_0325/control.py
@@ -332,7 +332,7 @@ class XEP_0325(BasePlugin):
self.sessions[session]["nodeDone"][node] = False
for node in self.sessions[session]["node_list"]:
- timer = asyncio.get_event_loop().call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
+ timer = self.xmpp.loop.call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
self.sessions[session]["commTimers"][node] = timer
self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback)
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index a9c260f0..fb2c3d31 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -32,14 +32,14 @@ except ImportError as e:
"Not all features will be available")
-def default_resolver():
+def default_resolver(loop):
"""Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns
is available. Otherwise, ``None``.
"""
if AIODNS_AVAILABLE:
- return aiodns.DNSResolver(loop=asyncio.get_event_loop(),
+ return aiodns.DNSResolver(loop=loop,
tries=1,
timeout=1.0)
return None
@@ -47,7 +47,7 @@ def default_resolver():
@asyncio.coroutine
def resolve(host, port=None, service=None, proto='tcp',
- resolver=None, use_ipv6=True, use_aiodns=True):
+ resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@@ -97,7 +97,7 @@ def resolve(host, port=None, service=None, proto='tcp',
log.debug("DNS: Use of IPv6 has been disabled.")
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
- resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
+ resolver = aiodns.DNSResolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal;
@@ -142,19 +142,19 @@ def resolve(host, port=None, service=None, proto='tcp',
if use_ipv6:
aaaa = yield from get_AAAA(host, resolver=resolver,
- use_aiodns=use_aiodns)
+ use_aiodns=use_aiodns, loop=loop)
for address in aaaa:
results.append((host, address, port))
a = yield from get_A(host, resolver=resolver,
- use_aiodns=use_aiodns)
+ use_aiodns=use_aiodns, loop=loop)
for address in a:
results.append((host, address, port))
return results
@asyncio.coroutine
-def get_A(host, resolver=None, use_aiodns=True):
+def get_A(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -177,7 +177,6 @@ def get_A(host, resolver=None, use_aiodns=True):
# If not using aiodns, attempt lookup using the OS level
# getaddrinfo() method.
if resolver is None or not use_aiodns:
- loop = asyncio.get_event_loop()
try:
recs = yield from loop.getaddrinfo(host, None,
family=socket.AF_INET,
@@ -198,7 +197,7 @@ def get_A(host, resolver=None, use_aiodns=True):
@asyncio.coroutine
-def get_AAAA(host, resolver=None, use_aiodns=True):
+def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -224,7 +223,6 @@ def get_AAAA(host, resolver=None, use_aiodns=True):
if not socket.has_ipv6:
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return []
- loop = asyncio.get_event_loop()
try:
recs = yield from loop.getaddrinfo(host, None,
family=socket.AF_INET6,
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 71873e48..866368bd 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol):
self._der_cert = None
+ # The asyncio event loop
+ self._loop = None
+
#: The default port to return when querying DNS records.
self.default_port = int(port)
@@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
+ @property
+ def loop(self):
+ if self._loop is None:
+ self._loop = asyncio.get_event_loop()
+ return self._loop
+
+ @loop.setter
+ def loop(self, value):
+ self._loop = value
+
def new_id(self):
"""Generate and return a new stream ID in hexadecimal form.
@@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol):
@asyncio.coroutine
def _connect_routine(self):
- loop = asyncio.get_event_loop()
self.event_when_connected = "connected"
try:
@@ -290,10 +302,10 @@ class XMLStream(asyncio.BaseProtocol):
self.dns_answers = None
try:
- yield from loop.create_connection(lambda: self,
- self.address[0],
- self.address[1],
- ssl=self.use_ssl)
+ yield from self.loop.create_connection(lambda: self,
+ self.address[0],
+ self.address[1],
+ ssl=self.use_ssl)
except Socket.gaierror as e:
self.event('connection_failed',
'No DNS record available for %s' % self.default_domain)
@@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol):
function will run forever. If timeout is a number, this function
will return after the given time in seconds.
"""
- loop = asyncio.get_event_loop()
if timeout is None:
if forever:
- loop.run_forever()
+ self.loop.run_forever()
else:
- loop.run_until_complete(self.disconnected)
+ self.loop.run_until_complete(self.disconnected)
else:
tasks = [asyncio.sleep(timeout)]
if not forever:
tasks.append(self.disconnected)
- loop.run_until_complete(asyncio.wait(tasks))
+ self.loop.run_until_complete(asyncio.wait(tasks))
def init_parser(self):
"""init the XML parser. The parser must always be reset for each new
@@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol):
elif self.xml_depth == 1:
# A stanza is an XML element that is a direct child of
# the root element, hence the check of depth == 1
- asyncio.get_event_loop().\
- idle_call(functools.partial(self.__spawn_event, xml))
+ self.loop.idle_call(functools.partial(self.__spawn_event, xml))
if self.xml_root is not None:
# Keep the root element empty of children to
# save on memory use.
@@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol):
If the handshake is successful, the XML stream will need
to be restarted.
"""
- loop = asyncio.get_event_loop()
self.event_when_connected = "tls_success"
if self.ciphers is not None:
@@ -478,9 +487,9 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
- ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context,
- sock=self.socket,
- server_hostname=self.address[0])
+ ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context,
+ sock=self.socket,
+ server_hostname=self.address[0])
@asyncio.coroutine
def ssl_coro():
try:
@@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol):
if port is None:
port = self.default_port
- resolver = default_resolver()
+ resolver = default_resolver(loop=self.loop)
self.configure_dns(resolver, domain=domain, port=port)
result = yield from resolve(domain, port,
service=self.dns_service,
resolver=resolver,
use_ipv6=self.use_ipv6,
- use_aiodns=self.use_aiodns)
+ use_aiodns=self.use_aiodns,
+ loop=self.loop)
return result
@asyncio.coroutine
@@ -746,14 +756,13 @@ class XMLStream(asyncio.BaseProtocol):
"""
if seconds is None:
seconds = RESPONSE_TIMEOUT
- loop = asyncio.get_event_loop()
cb = functools.partial(callback, *args, **kwargs)
if repeat:
- handle = loop.call_later(seconds, self._execute_and_reschedule,
- name, cb, seconds)
+ handle = self.loop.call_later(seconds, self._execute_and_reschedule,
+ name, cb, seconds)
else:
- handle = loop.call_later(seconds, self._execute_and_unschedule,
- name, cb)
+ handle = self.loop.call_later(seconds, self._execute_and_unschedule,
+ name, cb)
# Save that handle, so we can just cancel this scheduled event by
# canceling scheduled_events[name]
@@ -778,9 +787,8 @@ class XMLStream(asyncio.BaseProtocol):
be called after the given number of seconds.
"""
self._safe_cb_run(name, cb)
- loop = asyncio.get_event_loop()
- handle = loop.call_later(seconds, self._execute_and_reschedule,
- name, cb, seconds)
+ handle = self.loop.call_later(seconds, self._execute_and_reschedule,
+ name, cb, seconds)
self.scheduled_events[name] = handle
def _execute_and_unschedule(self, name, cb):