summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream/tostring.py
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream/tostring.py')
-rw-r--r--sleekxmpp/xmlstream/tostring.py32
1 files changed, 23 insertions, 9 deletions
diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py
index 2480f9b2..f22e7770 100644
--- a/sleekxmpp/xmlstream/tostring.py
+++ b/sleekxmpp/xmlstream/tostring.py
@@ -63,9 +63,11 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
default_ns = ''
stream_ns = ''
+ use_cdata = False
if stream:
default_ns = stream.default_ns
stream_ns = stream.stream_ns
+ use_cdata = stream.use_cdata
# Output the tag name and derived namespace of the element.
namespace = ''
@@ -81,7 +83,7 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
# Output escaped attribute values.
for attrib, value in xml.attrib.items():
- value = xml_escape(value)
+ value = escape(value, use_cdata)
if '}' not in attrib:
output.append(' %s="%s"' % (attrib, value))
else:
@@ -105,24 +107,24 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
# If there are additional child elements to serialize.
output.append(">")
if xml.text:
- output.append(xml_escape(xml.text))
+ output.append(escape(xml.text, use_cdata))
if len(xml):
for child in xml:
output.append(tostring(child, tag_xmlns, stanza_ns, stream))
output.append("</%s>" % tag_name)
elif xml.text:
# If we only have text content.
- output.append(">%s</%s>" % (xml_escape(xml.text), tag_name))
+ output.append(">%s</%s>" % (escape(xml.text, use_cdata), tag_name))
else:
# Empty element.
output.append(" />")
if xml.tail:
# If there is additional text after the element.
- output.append(xml_escape(xml.tail))
+ output.append(escape(xml.tail, use_cdata))
return ''.join(output)
-def xml_escape(text):
+def escape(text, use_cdata=False):
"""Convert special characters in XML to escape sequences.
:param string text: The XML text to convert.
@@ -132,12 +134,24 @@ def xml_escape(text):
if type(text) != types.UnicodeType:
text = unicode(text, 'utf-8', 'ignore')
- text = list(text)
escapes = {'&': '&amp;',
'<': '&lt;',
'>': '&gt;',
"'": '&apos;',
'"': '&quot;'}
- for i, c in enumerate(text):
- text[i] = escapes.get(c, c)
- return ''.join(text)
+
+ if not use_cdata:
+ text = list(text)
+ for i, c in enumerate(text):
+ text[i] = escapes.get(c, c)
+ return ''.join(text)
+ else:
+ escape_needed = False
+ for c in text:
+ if c in escapes:
+ escape_needed = True
+ break
+ if escape_needed:
+ escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>"))
+ return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped)
+ return text