summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/end_to_end/__main__.py29
1 files changed, 26 insertions, 3 deletions
diff --git a/tests/end_to_end/__main__.py b/tests/end_to_end/__main__.py
index f498585..c45e8b9 100644
--- a/tests/end_to_end/__main__.py
+++ b/tests/end_to_end/__main__.py
@@ -21,6 +21,17 @@ class MatchAll(MatcherBase):
class StanzaError(Exception):
+ """
+ Raised when a step fails.
+ """
+ pass
+
+
+class SkipStepError(Exception):
+ """
+ Raised by a step when it needs to be skiped, by running
+ the next available step immediately.
+ """
pass
@@ -69,6 +80,10 @@ class XMPPComponent(slixmpp.BaseXMPP):
self.stanza_checker(stanza)
except StanzaError as e:
self.error(e)
+ except SkipStepError:
+ # Run the next step and then re-handle this same stanza
+ self.run_scenario()
+ return self.handle_incoming_stanza(stanza)
self.stanza_checker = None
self.run_scenario()
@@ -97,6 +112,13 @@ def check_xpath(xpaths, stanza):
raise StanzaError("Received stanza ā€œ%sā€ did not match expected xpath ā€œ%sā€" % (stanza, xpath))
+def check_xpath_optional(xpaths, stanza):
+ try:
+ check_xpath(xpaths, stanza)
+ except StanzaError:
+ raise SkipStepError()
+
+
class Scenario:
"""Defines a list of actions that are executed in sequence, until one of
them throws an exception, or until the end. An action can be something
@@ -170,11 +192,12 @@ def send_stanza(stanza, xmpp, biboumi):
asyncio.get_event_loop().call_soon(xmpp.run_scenario)
-def expect_stanza(xpaths, xmpp, biboumi):
+def expect_stanza(xpaths, xmpp, biboumi, optional=False):
+ check_func = check_xpath if not optional else check_xpath_optional
if isinstance(xpaths, str):
- xmpp.stanza_checker = partial(check_xpath, [xpaths.format_map(common_replacements)])
+ xmpp.stanza_checker = partial(check_func, [xpaths.format_map(common_replacements)])
elif isinstance(xpaths, tuple):
- xmpp.stanza_checker = partial(check_xpath, [xpath.format_map(common_replacements) for xpath in xpaths])
+ xmpp.stanza_checker = partial(check_func, [xpath.format_map(common_replacements) for xpath in xpaths])
else:
print("Warning, from argument type passed to expect_stanza: %s" % (type(xpaths)))