diff options
Diffstat (limited to 'slixmpp/xmlstream/handler/coroutine_callback.py')
-rw-r--r-- | slixmpp/xmlstream/handler/coroutine_callback.py | 38 |
1 files changed, 27 insertions, 11 deletions
diff --git a/slixmpp/xmlstream/handler/coroutine_callback.py b/slixmpp/xmlstream/handler/coroutine_callback.py index 6568ba9f..d41cd7ba 100644 --- a/slixmpp/xmlstream/handler/coroutine_callback.py +++ b/slixmpp/xmlstream/handler/coroutine_callback.py @@ -4,8 +4,19 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from __future__ import annotations + +from asyncio import iscoroutinefunction, ensure_future +from typing import Optional, Callable, Awaitable, TYPE_CHECKING + +from slixmpp.xmlstream.stanzabase import StanzaBase from slixmpp.xmlstream.handler.base import BaseHandler -from slixmpp.xmlstream.asyncio import asyncio +from slixmpp.xmlstream.matcher.base import MatcherBase + +CoroutineFunction = Callable[[StanzaBase], Awaitable[None]] + +if TYPE_CHECKING: + from slixmpp.xmlstream.xmlstream import XMLStream class CoroutineCallback(BaseHandler): @@ -34,15 +45,20 @@ class CoroutineCallback(BaseHandler): instance this handler should monitor. """ - def __init__(self, name, matcher, pointer, once=False, - instream=False, stream=None): + _once: bool + _instream: bool + _pointer: CoroutineFunction + + def __init__(self, name: str, matcher: MatcherBase, + pointer: CoroutineFunction, once: bool = False, + instream: bool = False, stream: Optional[XMLStream] = None): BaseHandler.__init__(self, name, matcher, stream) - if not asyncio.iscoroutinefunction(pointer): + if not iscoroutinefunction(pointer): raise ValueError("Given function is not a coroutine") - async def pointer_wrapper(stanza, *args, **kwargs): + async def pointer_wrapper(stanza: StanzaBase) -> None: try: - await pointer(stanza, *args, **kwargs) + await pointer(stanza) except Exception as e: stanza.exception(e) @@ -50,29 +66,29 @@ class CoroutineCallback(BaseHandler): self._once = once self._instream = instream - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """Execute the callback during stream processing, if the callback was created with ``instream=True``. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ if self._once: self._destroy = True if self._instream: self.run(payload, True) - def run(self, payload, instream=False): + def run(self, payload: StanzaBase, instream: bool = False) -> None: """Execute the callback function with the matched stanza payload. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. :param bool instream: Force the handler to execute during stream processing. This should only be used by :meth:`prerun()`. Defaults to ``False``. """ if not self._instream or instream: - asyncio.ensure_future(self._pointer(payload)) + ensure_future(self._pointer(payload)) if self._once: self._destroy = True del self._pointer |