summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieui <mathieui@mathieui.net>2015-02-23 17:21:21 +0100
committermathieui <mathieui@mathieui.net>2015-02-23 17:32:31 +0100
commit8fd0d7c993a5c26663efb82c19d3f3e49539521a (patch)
treedd8d70be32b6a8b71899facbc5a6326dffdce15c
parent1450d3637729186efef8d61297cb03601879b63c (diff)
downloadslixmpp-8fd0d7c993a5c26663efb82c19d3f3e49539521a.tar.gz
slixmpp-8fd0d7c993a5c26663efb82c19d3f3e49539521a.tar.bz2
slixmpp-8fd0d7c993a5c26663efb82c19d3f3e49539521a.tar.xz
slixmpp-8fd0d7c993a5c26663efb82c19d3f3e49539521a.zip
Add a coroutine_wrapper decorator
This decorator checks for the coroutine=True keyword arg and wraps the result of the function call in a coroutine if it isn’t. This allows to have constructs like: @coroutine_wrapper def toto(xmpp, *, coroutine=False): if xmpp.cached: return xmpp.cached else: return xmpp.make_iq_get().send(coroutine=coroutine) @asyncio.coroutine def main(xmpp): result = yield from toto(xmpp, coroutine=True) xmpp.cached = result result2 = yield from toto(xmpp, coroutine=True) If the wrapper wasn’t there, the second fetch would fail. This decorator does not do anything if the coroutine argument is False.
-rw-r--r--slixmpp/xmlstream/asyncio.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/slixmpp/xmlstream/asyncio.py b/slixmpp/xmlstream/asyncio.py
index 4277868b..76195237 100644
--- a/slixmpp/xmlstream/asyncio.py
+++ b/slixmpp/xmlstream/asyncio.py
@@ -8,6 +8,7 @@ call_soon() ones. These callback are called only once each.
import asyncio
from asyncio import events
+from functools import wraps
import collections
@@ -32,3 +33,23 @@ cls.idle_call = idle_call
real_run_once = cls._run_once
cls._run_once = my_run_once
+
+def coroutine_wrapper(func):
+ """
+ Make sure the result of a function call is a coroutine
+ if the ``coroutine`` keyword argument is true.
+ """
+ def wrap_coro(result):
+ if asyncio.iscoroutinefunction(result):
+ return result
+ else:
+ return asyncio.coroutine(lambda: result)()
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if kwargs.get('coroutine', False):
+ return wrap_coro(func(*args, **kwargs))
+ else:
+ return func(*args, **kwargs)
+
+ return wrapper