diff options
Diffstat (limited to 'slixmpp')
-rw-r--r-- | slixmpp/test/integration.py | 61 | ||||
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 17 |
2 files changed, 77 insertions, 1 deletions
diff --git a/slixmpp/test/integration.py b/slixmpp/test/integration.py new file mode 100644 index 00000000..d15019cc --- /dev/null +++ b/slixmpp/test/integration.py @@ -0,0 +1,61 @@ +""" + Slixmpp: The Slick XMPP Library + Copyright (C) 2020 Mathieu Pasquet + This file is part of Slixmpp. + + See the file LICENSE for copying permission. +""" + +import asyncio +import os +try: + from unittest import IsolatedAsyncioTestCase +except ImportError: + # Python < 3.8 + # just to make sure the imports do not break, but + # not usable. + from unittest import TestCase as IsolatedAsyncioTestCase +from typing import ( + List, +) + +from slixmpp import JID +from slixmpp.clientxmpp import ClientXMPP + + +class SlixIntegration(IsolatedAsyncioTestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.clients = [] + self.addAsyncCleanup(self._destroy) + + def envjid(self, name): + """Get a JID from an env var""" + value = os.getenv(name) + return JID(value) + + def envstr(self, name): + """get a str from an env var""" + return os.getenv(name) + + def register_plugins(self, plugins: List[str]): + """Register plugins on all known clients""" + for plugin in plugins: + for client in self.clients: + client.register_plugin(plugin) + + def add_client(self, jid: JID, password: str): + """Register a new client""" + self.clients.append(ClientXMPP(jid, password)) + + async def connect_clients(self): + """Connect all clients""" + for client in self.clients: + client.connect() + await client.wait_until('session_start') + + async def _destroy(self): + """Kill all clients""" + for client in self.clients: + client.abort() diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index af494903..066d84df 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -12,7 +12,7 @@ :license: MIT, see LICENSE for more details """ -from typing import Optional, Set, Callable +from typing import Optional, Set, Callable, Any import functools import logging @@ -1130,3 +1130,18 @@ class XMLStream(asyncio.BaseProtocol): :param exception: An unhandled exception object. """ pass + + async def wait_until(self, event: str, timeout=30) -> Any: + """Utility method to wake on the next firing of an event. + (Registers a disposable handler on it) + + :param str event: Event to wait on. + :param int timeout: Timeout + """ + fut = asyncio.Future() + self.add_event_handler( + event, + fut.set_result, + disposable=True, + ) + return await asyncio.wait_for(fut, timeout) |