summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream/matcher/xmlmask.py
blob: d50b706e57a590a592ebdaa542e4cda097a604e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010  Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import logging

from xml.parsers.expat import ExpatError

from slixmpp.xmlstream.stanzabase import ET
from slixmpp.xmlstream.matcher.base import MatcherBase


log = logging.getLogger(__name__)


class MatchXMLMask(MatcherBase):

    """
    The XMLMask matcher selects stanzas whose XML matches a given
    XML pattern, or mask. For example, message stanzas with body elements
    could be matched using the mask:

    .. code-block:: xml

        <message xmlns="jabber:client"><body /></message>

    Use of XMLMask is discouraged, and
    :class:`~slixmpp.xmlstream.matcher.xpath.MatchXPath` or
    :class:`~slixmpp.xmlstream.matcher.stanzapath.StanzaPath`
    should be used instead.

    :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
                     object or XML string to use as a mask.
    :param default_ns: The namespace tried for a mask element's tag in
                       addition to the tag exactly as written in the mask.
                       Defaults to ``"jabber:client"``.
    """

    def __init__(self, criteria, default_ns='jabber:client'):
        MatcherBase.__init__(self, criteria)
        if isinstance(criteria, str):
            # Parse a string mask once, up front, rather than on every
            # comparison. (self._criteria is presumably set by
            # MatcherBase.__init__ -- not visible from this file.)
            self._criteria = ET.fromstring(self._criteria)
        self.default_ns = default_ns

    def setDefaultNS(self, ns):
        """Set the default namespace to use during comparisons.

        :param ns: The new namespace to use as the default.
        """
        self.default_ns = ns

    def match(self, xml):
        """Compare a stanza object or XML object against the stored XML mask.

        Overrides MatcherBase.match.

        :param xml: The stanza object or XML object to compare against.
                    An object exposing an ``xml`` attribute is unwrapped
                    to that attribute before the comparison.
        :return: ``True`` if the XML satisfies the mask, ``False`` otherwise.
        """
        if hasattr(xml, 'xml'):
            xml = xml.xml
        return self._mask_cmp(xml, self._criteria, True)

    def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'):
        """Compare an XML object against an XML mask.

        The source matches when its tag equals the mask's tag (either as
        written, or qualified with ``self.default_ns``), every attribute of
        the mask is present on the source with the same value, and every
        child element of the mask is matched by at least one child of the
        source. Text is only compared when both the mask and the source
        carry text; surrounding whitespace is ignored.

        :param source: The :class:`~xml.etree.ElementTree.Element` XML object
                       to compare against the mask.
        :param mask: The :class:`~xml.etree.ElementTree.Element` XML object
                     serving as the mask. Anything without an ``attrib``
                     attribute is parsed as an XML string first.
        :param use_ns: Forwarded unchanged to recursive calls; it does not
                       currently influence the comparison.
        :param default_ns: Accepted for interface compatibility but not
                           consulted; the instance's ``default_ns`` is
                           what the tag comparison uses.
                           Defaults to ``"__no_ns__"``.
        :return: ``True`` if ``source`` satisfies ``mask``, ``False``
                 otherwise, including when ``source`` is ``None`` or the
                 mask cannot be parsed.
        """
        if source is None:
            # If the element was not found. May happen during recursive calls.
            return False

        # Convert the mask to an XML object if it is a string.
        if not hasattr(mask, 'attrib'):
            try:
                mask = ET.fromstring(mask)
            except (ExpatError, ET.ParseError) as err:
                # ElementTree reports malformed XML as ParseError; the
                # ExpatError clause is kept for other parser back ends.
                # An unparsable mask cannot match anything, so stop here
                # instead of dereferencing the raw string below.
                log.warning("Expat error: %s\nIn parsing: %s", err, mask)
                return False

        # Accept the tag as written in the mask, or the same tag placed in
        # the matcher's default namespace.
        mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
        if source.tag not in (mask.tag, mask_ns_tag):
            return False

        # If the mask includes text, compare it. A source element without
        # text is not rejected by this check.
        if mask.text and source.text and \
           source.text.strip() != mask.text.strip():
            return False

        # Compare attributes. The stanza must include the attributes
        # defined by the mask, but may include others.
        for name, value in mask.attrib.items():
            if source.attrib.get(name, "__None__") != value:
                return False

        # Recursively check subelements: each child of the mask needs at
        # least one same-tag child of the source that satisfies it. The
        # same source child may satisfy several mask children.
        for subelement in mask:
            candidates = source.findall(subelement.tag)
            if not any(self._mask_cmp(other, subelement, use_ns)
                       for other in candidates):
                return False

        # Everything matches.
        return True