summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--sleekxmpp/xmlstream/tostring.py38
-rw-r--r--tests/test_tostring.py3
3 files changed, 11 insertions, 31 deletions
diff --git a/.gitignore b/.gitignore
index 602416e8..7c2b5bce 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,4 @@ sleekxmpp.egg-info/
*~
.baboon/
.DS_STORE
+*.iml
diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py
index c49abd3e..4d7976b1 100644
--- a/sleekxmpp/xmlstream/tostring.py
+++ b/sleekxmpp/xmlstream/tostring.py
@@ -16,6 +16,7 @@
from __future__ import unicode_literals
import sys
+from xml.etree.ElementTree import _escape_cdata, _escape_attrib
if sys.version_info < (3, 0):
import types
@@ -140,33 +141,12 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='',
def escape(text, use_cdata=False):
- """Convert special characters in XML to escape sequences.
+ encoding = 'utf-8'
- :param string text: The XML text to convert.
- :rtype: Unicode string
- """
- if sys.version_info < (3, 0):
- if type(text) != types.UnicodeType:
- text = unicode(text, 'utf-8', 'ignore')
-
- escapes = {'&': '&amp;',
- '<': '&lt;',
- '>': '&gt;',
- "'": '&apos;',
- '"': '&quot;'}
-
- if not use_cdata:
- text = list(text)
- for i, c in enumerate(text):
- text[i] = escapes.get(c, c)
- return ''.join(text)
- else:
- escape_needed = False
- for c in text:
- if c in escapes:
- escape_needed = True
- break
- if escape_needed:
- escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>"))
- return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped)
- return text
+ if use_cdata:
+ return _escape_cdata(text, encoding)
+
+ text = _escape_attrib(text, encoding)
+ if "'" in text:
+ text = text.replace("'", "&apos;")
+ return text
diff --git a/tests/test_tostring.py b/tests/test_tostring.py
index e6148533..be11ab03 100644
--- a/tests/test_tostring.py
+++ b/tests/test_tostring.py
@@ -34,8 +34,7 @@ class TestToString(SleekTest):
desired = """&lt;foo bar=&quot;baz&quot;&gt;&apos;Hi"""
desired += """ &amp; welcome!&apos;&lt;/foo&gt;"""
- self.failUnless(escaped == desired,
- "XML escaping did not work: %s." % escaped)
+ self.assertEqual(escaped, desired)
def testEmptyElement(self):
"""Test converting an empty element to a string."""