diff options
Diffstat (limited to 'sleekxmpp/xmlstream/matcher/xmlmask.py')
-rw-r--r-- | sleekxmpp/xmlstream/matcher/xmlmask.py | 204 |
1 files changed, 146 insertions, 58 deletions
diff --git a/sleekxmpp/xmlstream/matcher/xmlmask.py b/sleekxmpp/xmlstream/matcher/xmlmask.py index 89fd6422..2967a2af 100644 --- a/sleekxmpp/xmlstream/matcher/xmlmask.py +++ b/sleekxmpp/xmlstream/matcher/xmlmask.py @@ -5,63 +5,151 @@ See the file LICENSE for copying permission. """ -from . import base -from xml.etree import cElementTree + from xml.parsers.expat import ExpatError -ignore_ns = False - -class MatchXMLMask(base.MatcherBase): - - def __init__(self, criteria): - base.MatcherBase.__init__(self, criteria) - if type(criteria) == type(''): - self._criteria = cElementTree.fromstring(self._criteria) - self.default_ns = 'jabber:client' - - def setDefaultNS(self, ns): - self.default_ns = ns - - def match(self, xml): - if hasattr(xml, 'xml'): - xml = xml.xml - return self.maskcmp(xml, self._criteria, True) - - def maskcmp(self, source, maskobj, use_ns=False, default_ns='__no_ns__'): - """maskcmp(xmlobj, maskobj): - Compare etree xml object to etree xml object mask""" - use_ns = not ignore_ns - #TODO require namespaces - if source == None: #if element not found (happens during recursive check below) - return False - if not hasattr(maskobj, 'attrib'): #if the mask is a string, make it an xml obj - try: - maskobj = cElementTree.fromstring(maskobj) - except ExpatError: - logging.log(logging.WARNING, "Expat error: %s\nIn parsing: %s" % ('', maskobj)) - if not use_ns and source.tag.split('}', 1)[-1] != maskobj.tag.split('}', 1)[-1]: # strip off ns and compare - return False - if use_ns and (source.tag != maskobj.tag and "{%s}%s" % (self.default_ns, maskobj.tag) != source.tag ): - return False - if maskobj.text and source.text != maskobj.text: - return False - for attr_name in maskobj.attrib: #compare attributes - if source.attrib.get(attr_name, "__None__") != maskobj.attrib[attr_name]: - return False - #for subelement in maskobj.getiterator()[1:]: #recursively compare subelements - for subelement in maskobj: #recursively compare subelements - if use_ns: - if not self.maskcmp(source.find(subelement.tag), subelement, use_ns): - return False - else: - if not self.maskcmp(self.getChildIgnoreNS(source, subelement.tag), subelement, use_ns): - return False - return True - - def getChildIgnoreNS(self, xml, tag): - tag = tag.split('}')[-1] - try: - idx = [c.tag.split('}')[-1] for c in xml.getchildren()].index(tag) - except ValueError: - return None - return xml.getchildren()[idx] +from sleekxmpp.xmlstream.stanzabase import ET +from sleekxmpp.xmlstream.matcher.base import MatcherBase + + +# Flag indicating if the builtin XPath matcher should be used, which +# uses namespaces, or a custom matcher that ignores namespaces. +# Changing this will affect ALL XMLMask matchers. +IGNORE_NS = False + + +class MatchXMLMask(MatcherBase): + + """ + The XMLMask matcher selects stanzas whose XML matches a given + XML pattern, or mask. For example, message stanzas with body elements + could be matched using the mask: + + <message xmlns="jabber:client"><body /></message> + + Use of XMLMask is discouraged, and XPath or StanzaPath should be used + instead. + + The use of namespaces in the mask comparison is controlled by + IGNORE_NS. Setting IGNORE_NS to True will disable namespace based matching + for ALL XMLMask matchers. + + Methods: + match -- Overrides MatcherBase.match. + setDefaultNS -- Set the default namespace for the mask. + """ + + def __init__(self, criteria): + """ + Create a new XMLMask matcher. + + Arguments: + criteria -- Either an XML object or XML string to use as a mask. + """ + MatcherBase.__init__(self, criteria) + if isinstance(criteria, str): + self._criteria = ET.fromstring(self._criteria) + self.default_ns = 'jabber:client' + + def setDefaultNS(self, ns): + """ + Set the default namespace to use during comparisons. + + Arguments: + ns -- The new namespace to use as the default. + """ + self.default_ns = ns + + def match(self, xml): + """ + Compare a stanza object or XML object against the stored XML mask. + + Overrides MatcherBase.match. + + Arguments: + xml -- The stanza object or XML object to compare against. + """ + if hasattr(xml, 'xml'): + xml = xml.xml + return self._mask_cmp(xml, self._criteria, True) + + def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'): + """ + Compare an XML object against an XML mask. + + Arguments: + source -- The XML object to compare against the mask. + mask -- The XML object serving as the mask. + use_ns -- Indicates if namespaces should be respected during + the comparison. + default_ns -- The default namespace to apply to elements that + do not have a specified namespace. + Defaults to "__no_ns__". + """ + use_ns = not IGNORE_NS + + if source is None: + # If the element was not found. May happend during recursive calls. + return False + + # Convert the mask to an XML object if it is a string. + if not hasattr(mask, 'attrib'): + try: + mask = ET.fromstring(mask) + except ExpatError: + logging.log(logging.WARNING, + "Expat error: %s\nIn parsing: %s" % ('', mask)) + + if not use_ns: + # Compare the element without using namespaces. + source_tag = source.tag.split('}', 1)[-1] + mask_tag = mask.tag.split('}', 1)[-1] + if source_tag != mask_tag: + return False + else: + # Compare the element using namespaces + mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) + if source.tag not in [mask.tag, mask_ns_tag]: + return False + + # If the mask includes text, compare it. + if mask.text and source.text != mask.text: + return False + + # Compare attributes. The stanza must include the attributes + # defined by the mask, but may include others. + for name, value in mask.attrib.items(): + if source.attrib.get(name, "__None__") != value: + return False + + # Recursively check subelements. + for subelement in mask: + if use_ns: + if not self._mask_cmp(source.find(subelement.tag), + subelement, use_ns): + return False + else: + if not self._mask_cmp(self._get_child(source, subelement.tag), + subelement, use_ns): + return False + + # Everything matches. + return True + + def _get_child(self, xml, tag): + """ + Return a child element given its tag, ignoring namespace values. + + Returns None if the child was not found. + + Arguments: + xml -- The XML object to search for the given child tag. + tag -- The name of the subelement to find. + """ + tag = tag.split('}')[-1] + try: + children = [c.tag.split('}')[-1] for c in xml.getchildren()] + index = children.index(tag) + except ValueError: + return None + return xml.getchildren()[index] |