summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--sleekxmpp/xmlstream/tostring.py38
-rw-r--r--tests/test_tostring.py3
3 files changed, 31 insertions, 11 deletions
diff --git a/.gitignore b/.gitignore
index 7c2b5bce..602416e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,4 +12,3 @@ sleekxmpp.egg-info/
*~
.baboon/
.DS_STORE
-*.iml
diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py
index 4d7976b1..c49abd3e 100644
--- a/sleekxmpp/xmlstream/tostring.py
+++ b/sleekxmpp/xmlstream/tostring.py
@@ -16,7 +16,6 @@
from __future__ import unicode_literals
import sys
-from xml.etree.ElementTree import _escape_cdata, _escape_attrib
if sys.version_info < (3, 0):
import types
@@ -141,12 +140,33 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='',
def escape(text, use_cdata=False):
- encoding = 'utf-8'
+ """Convert special characters in XML to escape sequences.
- if use_cdata:
- return _escape_cdata(text, encoding)
-
- text = _escape_attrib(text, encoding)
- if "'" in text:
- text = text.replace("'", "&apos;")
- return text
+ :param string text: The XML text to convert.
+ :rtype: Unicode string
+ """
+ if sys.version_info < (3, 0):
+ if type(text) != types.UnicodeType:
+ text = unicode(text, 'utf-8', 'ignore')
+
+ escapes = {'&': '&amp;',
+ '<': '&lt;',
+ '>': '&gt;',
+ "'": '&apos;',
+ '"': '&quot;'}
+
+ if not use_cdata:
+ text = list(text)
+ for i, c in enumerate(text):
+ text[i] = escapes.get(c, c)
+ return ''.join(text)
+ else:
+ escape_needed = False
+ for c in text:
+ if c in escapes:
+ escape_needed = True
+ break
+ if escape_needed:
+ escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>"))
+ return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped)
+ return text
diff --git a/tests/test_tostring.py b/tests/test_tostring.py
index be11ab03..e6148533 100644
--- a/tests/test_tostring.py
+++ b/tests/test_tostring.py
@@ -34,7 +34,8 @@ class TestToString(SleekTest):
desired = """&lt;foo bar=&quot;baz&quot;&gt;&apos;Hi"""
desired += """ &amp; welcome!&apos;&lt;/foo&gt;"""
- self.assertEqual(escaped, desired)
+ self.failUnless(escaped == desired,
+ "XML escaping did not work: %s." % escaped)
def testEmptyElement(self):
"""Test converting an empty element to a string."""