summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream/matcher/xmlmask.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/xmlstream/matcher/xmlmask.py')
-rw-r--r--slixmpp/xmlstream/matcher/xmlmask.py28
1 files changed, 11 insertions, 17 deletions
diff --git a/slixmpp/xmlstream/matcher/xmlmask.py b/slixmpp/xmlstream/matcher/xmlmask.py
index d50b706e..b63e0f05 100644
--- a/slixmpp/xmlstream/matcher/xmlmask.py
+++ b/slixmpp/xmlstream/matcher/xmlmask.py
@@ -1,4 +1,3 @@
-
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
@@ -6,8 +5,9 @@
import logging
from xml.parsers.expat import ExpatError
+from xml.etree.ElementTree import Element
-from slixmpp.xmlstream.stanzabase import ET
+from slixmpp.xmlstream.stanzabase import ET, StanzaBase
from slixmpp.xmlstream.matcher.base import MatcherBase
@@ -33,32 +33,33 @@ class MatchXMLMask(MatcherBase):
:param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
object or XML string to use as a mask.
"""
+ _criteria: Element
- def __init__(self, criteria, default_ns='jabber:client'):
+ def __init__(self, criteria: str, default_ns: str = 'jabber:client'):
MatcherBase.__init__(self, criteria)
if isinstance(criteria, str):
- self._criteria = ET.fromstring(self._criteria)
+ self._criteria = ET.fromstring(criteria)
self.default_ns = default_ns
- def setDefaultNS(self, ns):
+ def setDefaultNS(self, ns: str) -> None:
"""Set the default namespace to use during comparisons.
:param ns: The new namespace to use as the default.
"""
self.default_ns = ns
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""Compare a stanza object or XML object against the stored XML mask.
Overrides MatcherBase.match.
:param xml: The stanza object or XML object to compare against.
"""
- if hasattr(xml, 'xml'):
- xml = xml.xml
- return self._mask_cmp(xml, self._criteria, True)
+ real_xml = xml.xml
+ return self._mask_cmp(real_xml, self._criteria, True)
- def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'):
+ def _mask_cmp(self, source: Element, mask: Element, use_ns: bool = False,
+ default_ns: str = '__no_ns__') -> bool:
"""Compare an XML object against an XML mask.
:param source: The :class:`~xml.etree.ElementTree.Element` XML object
@@ -75,13 +76,6 @@ class MatchXMLMask(MatcherBase):
# If the element was not found. May happen during recursive calls.
return False
- # Convert the mask to an XML object if it is a string.
- if not hasattr(mask, 'attrib'):
- try:
- mask = ET.fromstring(mask)
- except ExpatError:
- log.warning("Expat error: %s\nIn parsing: %s", '', mask)
-
mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
if source.tag not in [mask.tag, mask_ns_tag]:
return False