From 403b1802ecb9d40799b00073bd15baf0d0fd8349 Mon Sep 17 00:00:00 2001 From: Lance Stout Date: Thu, 24 Jan 2013 02:43:46 -0800 Subject: Update tostring to inject xmlns definitions when needed. --- sleekxmpp/xmlstream/tostring.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) (limited to 'sleekxmpp') diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py index 08d7ad02..f4157f7a 100644 --- a/sleekxmpp/xmlstream/tostring.py +++ b/sleekxmpp/xmlstream/tostring.py @@ -24,8 +24,8 @@ if sys.version_info < (3, 0): XML_NS = 'http://www.w3.org/XML/1998/namespace' -def tostring(xml=None, xmlns='', stream=None, - outbuffer='', top_level=False, open_only=False): +def tostring(xml=None, xmlns='', stream=None, outbuffer='', + top_level=False, open_only=False, namespaces=None): """Serialize an XML object to a Unicode string. If an outer xmlns is provided using ``xmlns``, then the current element's @@ -41,7 +41,8 @@ def tostring(xml=None, xmlns='', stream=None, during recursive calls. :param bool top_level: Indicates that the element is the outermost element. - + :param set namespaces: Track which namespaces are in active use so + that new ones can be declared when needed. :type xml: :py:class:`~xml.etree.ElementTree.Element` :type stream: :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream` @@ -63,6 +64,7 @@ def tostring(xml=None, xmlns='', stream=None, default_ns = '' stream_ns = '' use_cdata = False + if stream: default_ns = stream.default_ns stream_ns = stream.stream_ns @@ -82,6 +84,7 @@ def tostring(xml=None, xmlns='', stream=None, output.append(namespace) # Output escaped attribute values. + new_namespaces = set() for attrib, value in xml.attrib.items(): value = escape(value, use_cdata) if '}' not in attrib: @@ -92,9 +95,15 @@ def tostring(xml=None, xmlns='', stream=None, if stream and attrib_ns in stream.namespace_map: mapped_ns = stream.namespace_map[attrib_ns] if mapped_ns: - output.append(' %s:%s="%s"' % (mapped_ns, - attrib, - value)) + if namespaces is None: + namespaces = set() + if attrib_ns not in namespaces: + namespaces.add(attrib_ns) + new_namespaces.add(attrib_ns) + output.append(' xmlns:%s="%s"' % ( + mapped_ns, attrib_ns)) + output.append(' %s:%s="%s"' % ( + mapped_ns, attrib, value)) elif attrib_ns == XML_NS: output.append(' xml:%s="%s"' % (attrib, value)) @@ -110,7 +119,8 @@ def tostring(xml=None, xmlns='', stream=None, output.append(escape(xml.text, use_cdata)) if len(xml): for child in xml: - output.append(tostring(child, tag_xmlns, stream)) + output.append(tostring(child, tag_xmlns, stream, + namespaces=namespaces)) output.append("" % tag_name) elif xml.text: # If we only have text content. @@ -121,6 +131,11 @@ def tostring(xml=None, xmlns='', stream=None, if xml.tail: # If there is additional text after the element. output.append(escape(xml.tail, use_cdata)) + for ns in new_namespaces: + # Remove namespaces introduced in this context. This is necessary + # because the namespaces object continues to be shared with other + # contexts. + namespaces.remove(ns) return ''.join(output) -- cgit v1.2.3