summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream/matcher/xmlmask.py
blob: 87433d911099c6c280cf77223f46a7de51a636b2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
    SleekXMPP: The Sleek XMPP Library
    Copyright (C) 2010  Nathanael C. Fritz
    This file is part of SleekXMPP.

    See the file license.txt for copying permission.
"""
from . import base
from xml.etree import cElementTree
from xml.parsers.expat import ExpatError
import logging

# Module-wide switch read by MatchXMLMask.maskcmp (it overrides the method's
# use_ns argument). When True, element tags are compared by local name only,
# ignoring any "{namespace}" prefix. When False, a tag must match the mask's
# tag exactly or equal the mask's tag qualified with the matcher's default
# namespace.
ignore_ns = False

class MatchXMLMask(base.MatcherBase):
	"""Match stanzas against an XML "mask".

	A mask is an XML fragment given as a string or an etree element. An
	element satisfies the mask when:
	  * its tag equals the mask's tag (see maskcmp for namespace rules),
	  * the mask's text, if non-empty, equals the element's text,
	  * every attribute in the mask is present with the same value, and
	  * every child of the mask is satisfied, recursively, by at least one
	    child of the element with the same tag.
	Extra attributes and children on the element are allowed.
	"""

	def __init__(self, criteria):
		"""Create a matcher for *criteria* (an XML string or etree element).

		String criteria are parsed once here rather than on every match()
		call. A malformed string raises whatever the parser raises.
		"""
		base.MatcherBase.__init__(self, criteria)
		if isinstance(criteria, str):
			self._criteria = cElementTree.fromstring(self._criteria)
		# Namespace tried for mask tags that carry no "{ns}" prefix.
		self.default_ns = 'jabber:client'

	def setDefaultNS(self, ns):
		"""Set the namespace tried for mask tags written without one."""
		self.default_ns = ns

	def match(self, xml):
		"""Return True if *xml* satisfies this matcher's mask.

		*xml* may be an etree element, or any object exposing one through an
		``xml`` attribute (which is unwrapped first).
		"""
		if hasattr(xml, 'xml'):
			xml = xml.xml
		return self.maskcmp(xml, self._criteria, True)

	def maskcmp(self, source, maskobj, use_ns=False, default_ns='__no_ns__'):
		"""maskcmp(xmlobj, maskobj):
		Compare etree xml object to etree xml object mask.

		Arguments:
			source  -- etree element to test; None never matches.
			maskobj -- mask as an etree element or an XML string. A string
			           that cannot be parsed is logged and never matches.
			use_ns  -- ignored: the module-level ignore_ns switch decides
			           whether namespaces are compared (kept for API
			           compatibility).
			default_ns -- unused; kept for API compatibility.

		When namespaces are compared, source's tag must equal the mask's
		tag exactly or equal the mask's tag qualified with self.default_ns.
		Otherwise only local names (the part after "}") are compared.

		Returns True if *source* satisfies *maskobj*, else False.
		"""
		# The global switch deliberately overrides the argument; existing
		# callers rely on toggling ignore_ns to change matching behaviour.
		use_ns = not ignore_ns
		# TODO: consider requiring namespaces on masks.
		if source is None:
			return False
		if not hasattr(maskobj, 'attrib'):
			# The mask is still a string; turn it into an element.
			try:
				maskobj = cElementTree.fromstring(maskobj)
			except (ExpatError, SyntaxError):
				# ElementTree signals malformed XML with ParseError, a
				# SyntaxError subclass. A mask we cannot parse matches nothing.
				logging.exception("Expat error parsing: %s", maskobj)
				return False

		# Compare the element tags.
		if use_ns:
			qualified = "{%s}%s" % (self.default_ns, maskobj.tag)
			if source.tag not in (maskobj.tag, qualified):
				return False
		elif self._localName(source.tag) != self._localName(maskobj.tag):
			return False

		# Text is only constrained when the mask specifies some.
		if maskobj.text and source.text != maskobj.text:
			return False

		# Every mask attribute must be present with an identical value.
		for attr_name, attr_value in maskobj.attrib.items():
			if attr_name not in source.attrib:
				return False
			if source.attrib[attr_name] != attr_value:
				return False

		# Every mask child must be satisfied by at least one same-tag child
		# of source (not merely the first one with that tag).
		for subelement in maskobj:
			if use_ns:
				candidates = source.findall(subelement.tag)
			else:
				candidates = self._iterChildrenIgnoreNS(source, subelement.tag)
			if not any(self.maskcmp(child, subelement, use_ns)
					for child in candidates):
				return False
		return True

	@staticmethod
	def _localName(tag):
		"""Return *tag* without its "{namespace}" prefix, if any."""
		return tag.split('}', 1)[-1]

	def _iterChildrenIgnoreNS(self, xml, tag):
		"""Yield the children of *xml* whose local name equals *tag*'s."""
		wanted = self._localName(tag)
		for child in xml:
			if self._localName(child.tag) == wanted:
				yield child

	def getChildIgnoreNS(self, xml, tag):
		"""Return the first child of *xml* whose local name matches *tag*.

		Namespaces on both *tag* and the children are ignored. Returns None
		when no child matches.
		"""
		return next(self._iterChildrenIgnoreNS(xml, tag), None)