diff options
author | Lance Stout <lancestout@gmail.com> | 2012-07-24 20:01:18 -0700 |
---|---|---|
committer | Lance Stout <lancestout@gmail.com> | 2012-07-24 20:01:18 -0700 |
commit | c42f1ad4c79863261977a9c5ea3b33be0b51b946 (patch) | |
tree | 8eee86ddb082f51dea0866f16146bcd1f4f13c1f | |
parent | a3ec1af2053bc0be4864ae290e6e5fc39f3fd5fe (diff) | |
parent | 9d8de7fc15afc39a666d2ac16b62a068dfc55112 (diff) | |
download | slixmpp-c42f1ad4c79863261977a9c5ea3b33be0b51b946.tar.gz slixmpp-c42f1ad4c79863261977a9c5ea3b33be0b51b946.tar.bz2 slixmpp-c42f1ad4c79863261977a9c5ea3b33be0b51b946.tar.xz slixmpp-c42f1ad4c79863261977a9c5ea3b33be0b51b946.zip |
Merge branch 'master' into develop
-rwxr-xr-x | setup.py | 1 | ||||
-rw-r--r-- | sleekxmpp/__init__.py | 1 | ||||
-rw-r--r-- | sleekxmpp/clientxmpp.py | 13 | ||||
-rw-r--r-- | sleekxmpp/jid.py | 541 | ||||
-rw-r--r-- | sleekxmpp/plugins/__init__.py | 1 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0047/stream.py | 7 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0084/avatar.py | 16 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0084/stanza.py | 2 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0106.py | 26 | ||||
-rw-r--r-- | sleekxmpp/plugins/xep_0153/vcard_avatar.py | 3 | ||||
-rw-r--r-- | sleekxmpp/stanza/error.py | 2 | ||||
-rw-r--r-- | sleekxmpp/test/livesocket.py | 10 | ||||
-rw-r--r-- | sleekxmpp/test/mocksocket.py | 10 | ||||
-rw-r--r-- | sleekxmpp/test/sleektest.py | 7 | ||||
-rw-r--r-- | sleekxmpp/util/__init__.py | 23 | ||||
-rw-r--r-- | sleekxmpp/util/stringprep_profiles.py | 119 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/__init__.py | 2 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/handler/waiter.py | 9 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/jid.py | 149 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/scheduler.py | 10 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/tostring.py | 32 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 17 | ||||
-rw-r--r-- | tests/test_jid.py | 143 | ||||
-rw-r--r-- | tests/test_tostring.py | 4 |
24 files changed, 936 insertions, 212 deletions
@@ -49,6 +49,7 @@ packages = [ 'sleekxmpp', 'sleekxmpp/stanza', 'sleekxmpp/test', 'sleekxmpp/roster', + 'sleekxmpp/util', 'sleekxmpp/xmlstream', 'sleekxmpp/xmlstream/matcher', 'sleekxmpp/xmlstream/handler', diff --git a/sleekxmpp/__init__.py b/sleekxmpp/__init__.py index a1f1c0f1..f0dc2ce2 100644 --- a/sleekxmpp/__init__.py +++ b/sleekxmpp/__init__.py @@ -10,6 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP from sleekxmpp.clientxmpp import ClientXMPP from sleekxmpp.componentxmpp import ComponentXMPP from sleekxmpp.stanza import Message, Presence, Iq +from sleekxmpp.jid import JID, InvalidJID from sleekxmpp.xmlstream.handler import * from sleekxmpp.xmlstream import XMLStream, RestartStream from sleekxmpp.xmlstream.matcher import * diff --git a/sleekxmpp/clientxmpp.py b/sleekxmpp/clientxmpp.py index 48637dad..e3b434e9 100644 --- a/sleekxmpp/clientxmpp.py +++ b/sleekxmpp/clientxmpp.py @@ -179,8 +179,7 @@ class ClientXMPP(BaseXMPP): self._stream_feature_order.remove((order, name)) self._stream_feature_order.sort() - def update_roster(self, jid, name=None, subscription=None, groups=[], - block=True, timeout=None, callback=None): + def update_roster(self, jid, **kwargs): """Add or change a roster item. :param jid: The JID of the entry to modify. @@ -201,6 +200,16 @@ class ClientXMPP(BaseXMPP): Will be executed when the roster is received. Implies ``block=False``. """ + current = self.client_roster[jid] + + name = kwargs.get('name', current['name']) + subscription = kwargs.get('subscription', current['subscription']) + groups = kwargs.get('groups', current['groups']) + + block = kwargs.get('block', True) + timeout = kwargs.get('timeout', None) + callback = kwargs.get('callback', None) + return self.client_roster.update(jid, name, subscription, groups, block, timeout, callback) diff --git a/sleekxmpp/jid.py b/sleekxmpp/jid.py new file mode 100644 index 00000000..9e9c0d0b --- /dev/null +++ b/sleekxmpp/jid.py @@ -0,0 +1,541 @@ +# -*- coding: utf-8 -*- +""" + sleekxmpp.jid + ~~~~~~~~~~~~~~~~~~~~~~~ + + This module allows for working with Jabber IDs (JIDs). + + Part of SleekXMPP: The Sleek XMPP Library + + :copyright: (c) 2011 Nathanael C. Fritz + :license: MIT, see LICENSE for more details +""" + +from __future__ import unicode_literals + +import re +import socket +import stringprep +import encodings.idna + +from sleekxmpp.util import stringprep_profiles + +#: These characters are not allowed to appear in a JID. +ILLEGAL_CHARS = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + \ + '\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19' + \ + '\x1a\x1b\x1c\x1d\x1e\x1f' + \ + ' !"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~\x7f' + +#: The basic regex pattern that a JID must match in order to determine +#: the local, domain, and resource parts. This regex does NOT do any +#: validation, which requires application of nodeprep, resourceprep, etc. +JID_PATTERN = "^(?:([^\"&'/:<>@]{1,1023})@)?([^/@]{1,1023})(?:/(.{1,1023}))?$" + +#: The set of escape sequences for the characters not allowed by nodeprep. +JID_ESCAPE_SEQUENCES = set(['\\20', '\\22', '\\26', '\\27', '\\2f', + '\\3a', '\\3c', '\\3e', '\\40', '\\5c']) + +#: A mapping of unallowed characters to their escape sequences. An escape +#: sequence for '\' is also included since it must also be escaped in +#: certain situations. +JID_ESCAPE_TRANSFORMATIONS = {' ': '\\20', + '"': '\\22', + '&': '\\26', + "'": '\\27', + '/': '\\2f', + ':': '\\3a', + '<': '\\3c', + '>': '\\3e', + '@': '\\40', + '\\': '\\5c'} + +#: The reverse mapping of escape sequences to their original forms. +JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ', + '\\22': '"', + '\\26': '&', + '\\27': "'", + '\\2f': '/', + '\\3a': ':', + '\\3c': '<', + '\\3e': '>', + '\\40': '@', + '\\5c': '\\'} + + +# pylint: disable=c0103 +#: The nodeprep profile of stringprep used to validate the local, +#: or username, portion of a JID. +nodeprep = stringprep_profiles.create( + nfkc=True, + bidi=True, + mappings=[ + stringprep_profiles.b1_mapping, + stringprep_profiles.c12_mapping], + prohibited=[ + stringprep.in_table_c11, + stringprep.in_table_c12, + stringprep.in_table_c21, + stringprep.in_table_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9, + lambda c: c in ' \'"&/:<>@'], + unassigned=[stringprep.in_table_a1]) + +# pylint: disable=c0103 +#: The resourceprep profile of stringprep, which is used to validate +#: the resource portion of a JID. +resourceprep = stringprep_profiles.create( + nfkc=True, + bidi=True, + mappings=[stringprep_profiles.b1_mapping], + prohibited=[ + stringprep.in_table_c12, + stringprep.in_table_c21, + stringprep.in_table_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9], + unassigned=[stringprep.in_table_a1]) + + +def _parse_jid(data): + """ + Parse string data into the node, domain, and resource + components of a JID, if possible. + + :param string data: A string that is potentially a JID. + + :raises InvalidJID: + + :returns: tuple of the validated local, domain, and resource strings + """ + match = re.match(JID_PATTERN, data) + if not match: + raise InvalidJID('JID could not be parsed') + + (node, domain, resource) = match.groups() + + node = _validate_node(node) + domain = _validate_domain(domain) + resource = _validate_resource(resource) + + return node, domain, resource + + +def _validate_node(node): + """Validate the local, or username, portion of a JID. + + :raises InvalidJID: + + :returns: The local portion of a JID, as validated by nodeprep. + """ + try: + if node is not None: + node = nodeprep(node) + + if not node: + raise InvalidJID('Localpart must not be 0 bytes') + if len(node) > 1023: + raise InvalidJID('Localpart must be less than 1024 bytes') + return node + except stringprep_profiles.StringPrepError: + raise InvalidJID('Invalid local part') + + +def _validate_domain(domain): + """Validate the domain portion of a JID. + + IP literal addresses are left as-is, if valid. Domain names + are stripped of any trailing label separators (`.`), and are + checked with the nameprep profile of stringprep. If the given + domain is actually a punyencoded version of a domain name, it + is converted back into its original Unicode form. Domains must + also not start or end with a dash (`-`). + + :raises InvalidJID: + + :returns: The validated domain name + """ + ip_addr = False + + # First, check if this is an IPv4 address + try: + socket.inet_aton(domain) + ip_addr = True + except socket.error: + pass + + # Check if this is an IPv6 address + if not ip_addr and hasattr(socket, 'inet_pton'): + try: + socket.inet_pton(socket.AF_INET6, domain.strip('[]')) + domain = '[%s]' % domain.strip('[]') + ip_addr = True + except socket.error: + pass + + if not ip_addr: + # This is a domain name, which must be checked further + + if domain and domain[-1] == '.': + domain = domain[:-1] + + domain_parts = [] + for label in domain.split('.'): + try: + label = encodings.idna.nameprep(label) + encodings.idna.ToASCII(label) + pass_nameprep = True + except UnicodeError: + pass_nameprep = False + + if not pass_nameprep: + raise InvalidJID('Could not encode domain as ASCII') + + if label.startswith('xn--'): + label = encodings.idna.ToUnicode(label) + + for char in label: + if char in ILLEGAL_CHARS: + raise InvalidJID('Domain contains illegar characters') + + if '-' in (label[0], label[-1]): + raise InvalidJID('Domain started or ended with -') + + domain_parts.append(label) + domain = '.'.join(domain_parts) + + if not domain: + raise InvalidJID('Domain must not be 0 bytes') + if len(domain) > 1023: + raise InvalidJID('Domain must be less than 1024 bytes') + + return domain + + +def _validate_resource(resource): + """Validate the resource portion of a JID. + + :raises InvalidJID: + + :returns: The local portion of a JID, as validated by resourceprep. + """ + try: + if resource is not None: + resource = resourceprep(resource) + + if not resource: + raise InvalidJID('Resource must not be 0 bytes') + if len(resource) > 1023: + raise InvalidJID('Resource must be less than 1024 bytes') + return resource + except stringprep_profiles.StringPrepError: + raise InvalidJID('Invalid resource') + + +def _escape_node(node): + """Escape the local portion of a JID.""" + result = [] + + for i, char in enumerate(node): + if char == '\\': + if ''.join((node[i:i+3])) in JID_ESCAPE_SEQUENCES: + result.append('\\5c') + continue + result.append(char) + + for i, char in enumerate(result): + if char != '\\': + result[i] = JID_ESCAPE_TRANSFORMATIONS.get(char, char) + + escaped = ''.join(result) + + if escaped.startswith('\\20') or escaped.endswith('\\20'): + raise InvalidJID('Escaped local part starts or ends with "\\20"') + + _validate_node(escaped) + + return escaped + + +def _unescape_node(node): + """Unescape a local portion of a JID. + + .. note:: + The unescaped local portion is meant ONLY for presentation, + and should not be used for other purposes. + """ + unescaped = [] + seq = '' + for i, char in enumerate(node): + if char == '\\': + seq = node[i:i+3] + if seq not in JID_ESCAPE_SEQUENCES: + seq = '' + if seq: + if len(seq) == 3: + unescaped.append(JID_UNESCAPE_TRANSFORMATIONS.get(seq, char)) + + # Pop character off the escape sequence, and ignore it + seq = seq[1:] + else: + unescaped.append(char) + unescaped = ''.join(unescaped) + + return unescaped + + +def _format_jid(local=None, domain=None, resource=None): + """Format the given JID components into a full or bare JID. + + :param string local: Optional. The local portion of the JID. + :param string domain: Required. The domain name portion of the JID. + :param strin resource: Optional. The resource portion of the JID. + + :return: A full or bare JID string. + """ + result = [] + if local: + result.append(local) + result.append('@') + if domain: + result.append(domain) + if resource: + result.append('/') + result.append(resource) + return ''.join(result) + + +class InvalidJID(ValueError): + """ + Raised when attempting to create a JID that does not pass validation. + + It can also be raised if modifying an existing JID in such a way as + to make it invalid, such trying to remove the domain from an existing + full JID while the local and resource portions still exist. + """ + +# pylint: disable=R0903 +class UnescapedJID(object): + + """ + .. versionadded:: 1.1.10 + """ + + def __init__(self, local, domain, resource): + self._jid = (local, domain, resource) + + # pylint: disable=R0911 + def __getattr__(self, name): + """Retrieve the given JID component. + + :param name: one of: user, server, domain, resource, + full, or bare. + """ + if name == 'resource': + return self._jid[2] or '' + elif name in ('user', 'username', 'local', 'node'): + return self._jid[0] or '' + elif name in ('server', 'domain', 'host'): + return self._jid[1] or '' + elif name in ('full', 'jid'): + return _format_jid(*self._jid) + elif name == 'bare': + return _format_jid(self._jid[0], self._jid[1]) + elif name == '_jid': + return getattr(super(JID, self), '_jid') + else: + return None + + def __str__(self): + """Use the full JID as the string value.""" + return _format_jid(*self._jid) + + def __repr__(self): + """Use the full JID as the representation.""" + return self.__str__() + + +class JID(object): + + """ + A representation of a Jabber ID, or JID. + + Each JID may have three components: a user, a domain, and an optional + resource. For example: user@domain/resource + + When a resource is not used, the JID is called a bare JID. + The JID is a full JID otherwise. + + **JID Properties:** + :jid: Alias for ``full``. + :full: The string value of the full JID. + :bare: The string value of the bare JID. + :user: The username portion of the JID. + :username: Alias for ``user``. + :local: Alias for ``user``. + :node: Alias for ``user``. + :domain: The domain name portion of the JID. + :server: Alias for ``domain``. + :host: Alias for ``domain``. + :resource: The resource portion of the JID. + + :param string jid: + A string of the form ``'[user@]domain[/resource]'``. + :param string local: + Optional. Specify the local, or username, portion + of the JID. If provided, it will override the local + value provided by the `jid` parameter. The given + local value will also be escaped if necessary. + :param string domain: + Optional. Specify the domain of the JID. If + provided, it will override the domain given by + the `jid` parameter. + :param string resource: + Optional. Specify the resource value of the JID. + If provided, it will override the domain given + by the `jid` parameter. + + :raises InvalidJID: + """ + + # pylint: disable=W0212 + def __init__(self, jid=None, **kwargs): + self._jid = (None, None, None) + + if jid is None or jid == '': + jid = (None, None, None) + elif not isinstance(jid, JID): + jid = _parse_jid(jid) + else: + jid = jid._jid + + local, domain, resource = jid + + local = kwargs.get('local', local) + domain = kwargs.get('domain', domain) + resource = kwargs.get('resource', resource) + + if 'local' in kwargs: + local = _escape_node(local) + if 'domain' in kwargs: + domain = _validate_domain(domain) + if 'resource' in kwargs: + resource = _validate_resource(resource) + + self._jid = (local, domain, resource) + + def unescape(self): + """Return an unescaped JID object. + + Using an unescaped JID is preferred for displaying JIDs + to humans, and they should NOT be used for any other + purposes than for presentation. + + :return: :class:`UnescapedJID` + + .. versionadded:: 1.1.10 + """ + return UnescapedJID(_unescape_node(self._jid[0]), + self._jid[1], + self._jid[2]) + + def regenerate(self): + """No-op + + .. deprecated:: 1.1.10 + """ + pass + + def reset(self, data): + """Start fresh from a new JID string. + + :param string data: A string of the form ``'[user@]domain[/resource]'``. + + .. deprecated:: 1.1.10 + """ + self._jid = JID(data)._jid + + # pylint: disable=R0911 + def __getattr__(self, name): + """Retrieve the given JID component. + + :param name: one of: user, server, domain, resource, + full, or bare. + """ + if name == 'resource': + return self._jid[2] or '' + elif name in ('user', 'username', 'local', 'node'): + return self._jid[0] or '' + elif name in ('server', 'domain', 'host'): + return self._jid[1] or '' + elif name in ('full', 'jid'): + return _format_jid(*self._jid) + elif name == 'bare': + return _format_jid(self._jid[0], self._jid[1]) + elif name == '_jid': + return getattr(super(JID, self), '_jid') + else: + return None + + # pylint: disable=W0212 + def __setattr__(self, name, value): + """Update the given JID component. + + :param name: one of: ``user``, ``username``, ``local``, + ``node``, ``server``, ``domain``, ``host``, + ``resource``, ``full``, ``jid``, or ``bare``. + :param value: The new string value of the JID component. + """ + if name == 'resource': + self._jid = JID(self, resource=value)._jid + elif name in ('user', 'username', 'local', 'node'): + self._jid = JID(self, local=value)._jid + elif name in ('server', 'domain', 'host'): + self._jid = JID(self, domain=value)._jid + elif name in ('full', 'jid'): + self._jid = JID(value)._jid + elif name == 'bare': + parsed = JID(value)._jid + self._jid = (parsed[0], parsed[1], self._jid[2]) + elif name == '_jid': + super(JID, self).__setattr__('_jid', value) + + def __str__(self): + """Use the full JID as the string value.""" + return _format_jid(*self._jid) + + def __repr__(self): + """Use the full JID as the representation.""" + return self.__str__() + + # pylint: disable=W0212 + def __eq__(self, other): + """Two JIDs are equal if they have the same full JID value.""" + if isinstance(other, UnescapedJID): + return False + + other = JID(other) + return self._jid == other._jid + + # pylint: disable=W0212 + def __ne__(self, other): + """Two JIDs are considered unequal if they are not equal.""" + return not self == other + + def __hash__(self): + """Hash a JID based on the string version of its full JID.""" + return hash(self.__str__()) + + def __copy__(self): + """Generate a duplicate JID.""" + return JID(self) diff --git a/sleekxmpp/plugins/__init__.py b/sleekxmpp/plugins/__init__.py index dbab2d1c..270626ed 100644 --- a/sleekxmpp/plugins/__init__.py +++ b/sleekxmpp/plugins/__init__.py @@ -37,6 +37,7 @@ __all__ = [ 'xep_0085', # Chat State Notifications 'xep_0086', # Legacy Error Codes 'xep_0092', # Software Version + 'xep_0106', # JID Escaping 'xep_0107', # User Mood 'xep_0108', # User Activity 'xep_0115', # Entity Capabilities diff --git a/sleekxmpp/plugins/xep_0047/stream.py b/sleekxmpp/plugins/xep_0047/stream.py index 49f56f36..b49a077b 100644 --- a/sleekxmpp/plugins/xep_0047/stream.py +++ b/sleekxmpp/plugins/xep_0047/stream.py @@ -1,11 +1,8 @@ import socket import threading import logging -try: - import queue -except ImportError: - import Queue as queue +from sleekxmpp.util import Queue from sleekxmpp.exceptions import XMPPError @@ -33,7 +30,7 @@ class IBBytestream(object): self.stream_in_closed = threading.Event() self.stream_out_closed = threading.Event() - self.recv_queue = queue.Queue() + self.recv_queue = Queue() self.send_window = threading.BoundedSemaphore(value=self.window_size) self.window_ids = set() diff --git a/sleekxmpp/plugins/xep_0084/avatar.py b/sleekxmpp/plugins/xep_0084/avatar.py index bbac330a..03711871 100644 --- a/sleekxmpp/plugins/xep_0084/avatar.py +++ b/sleekxmpp/plugins/xep_0084/avatar.py @@ -41,6 +41,9 @@ class XEP_0084(BasePlugin): def session_bind(self, jid): self.xmpp['xep_0163'].register_pep('avatar_metadata', MetaData) + def generate_id(self, data): + return hashlib.sha1(data).hexdigest() + def retrieve_avatar(self, jid, id, url=None, ifrom=None, block=True, callback=None, timeout=None): return self.xmpp['xep_0060'].get_item(jid, Data.namespace, id, @@ -54,8 +57,7 @@ class XEP_0084(BasePlugin): payload = Data() payload['value'] = data return self.xmpp['xep_0163'].publish(payload, - node=Data.namespace, - id=hashlib.sha1(data).hexdigest(), + id=self.generate_id(data), ifrom=ifrom, block=block, callback=callback, @@ -72,12 +74,12 @@ class XEP_0084(BasePlugin): height=info.get('height', ''), width=info.get('width', ''), url=info.get('url', '')) - for pointer in pointers: - metadata.add_pointer(pointer) - return self.xmpp['xep_0163'].publish(payload, - node=Data.namespace, - id=hashlib.sha1(data).hexdigest(), + if pointers is not None: + for pointer in pointers: + metadata.add_pointer(pointer) + + return self.xmpp['xep_0163'].publish(metadata, ifrom=ifrom, block=block, callback=callback, diff --git a/sleekxmpp/plugins/xep_0084/stanza.py b/sleekxmpp/plugins/xep_0084/stanza.py index 1b204471..e9133998 100644 --- a/sleekxmpp/plugins/xep_0084/stanza.py +++ b/sleekxmpp/plugins/xep_0084/stanza.py @@ -43,7 +43,7 @@ class MetaData(ElementBase): info = Info() info.values = {'id': id, 'type': itype, - 'bytes': ibytes, + 'bytes': '%s' % ibytes, 'height': height, 'width': width, 'url': url} diff --git a/sleekxmpp/plugins/xep_0106.py b/sleekxmpp/plugins/xep_0106.py new file mode 100644 index 00000000..1859a77b --- /dev/null +++ b/sleekxmpp/plugins/xep_0106.py @@ -0,0 +1,26 @@ +""" + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2012 Nathanael C. Fritz, Lance J.T. Stout + This file is part of SleekXMPP. + + See the file LICENSE for copying permission. +""" + + +from sleekxmpp.plugins import BasePlugin, register_plugin + + +class XEP_0106(BasePlugin): + + name = 'xep_0106' + description = 'XEP-0106: JID Escaping' + dependencies = set(['xep_0030']) + + def session_bind(self, jid): + self.xmpp['xep_0030'].add_feature(feature='jid\\20escaping') + + def plugin_end(self): + self.xmpp['xep_0030'].del_feature(feature='jid\\20escaping') + + +register_plugin(XEP_0106) diff --git a/sleekxmpp/plugins/xep_0153/vcard_avatar.py b/sleekxmpp/plugins/xep_0153/vcard_avatar.py index 6b70e33e..bec792cb 100644 --- a/sleekxmpp/plugins/xep_0153/vcard_avatar.py +++ b/sleekxmpp/plugins/xep_0153/vcard_avatar.py @@ -75,6 +75,9 @@ class XEP_0153(BasePlugin): return stanza def _reset_hash(self, jid=None): + if jid is None: + jid = self.xmpp.boundjid + own_jid = (jid.bare == self.xmpp.boundjid.bare) if self.xmpp.is_component: own_jid = (jid.domain == self.xmpp.boundjid.domain) diff --git a/sleekxmpp/stanza/error.py b/sleekxmpp/stanza/error.py index 60bc65bc..56558ba8 100644 --- a/sleekxmpp/stanza/error.py +++ b/sleekxmpp/stanza/error.py @@ -52,7 +52,7 @@ class Error(ElementBase): name = 'error' plugin_attrib = 'error' interfaces = set(('code', 'condition', 'text', 'type', - 'gone', 'redirect')) + 'gone', 'redirect', 'by')) sub_interfaces = set(('text',)) plugin_attrib_map = {} plugin_tag_map = {} diff --git a/sleekxmpp/test/livesocket.py b/sleekxmpp/test/livesocket.py index 80d63307..d70ee4eb 100644 --- a/sleekxmpp/test/livesocket.py +++ b/sleekxmpp/test/livesocket.py @@ -8,10 +8,8 @@ import socket import threading -try: - import queue -except ImportError: - import Queue as queue + +from sleekxmpp.util import Queue class TestLiveSocket(object): @@ -39,8 +37,8 @@ class TestLiveSocket(object): """ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.recv_buffer = [] - self.recv_queue = queue.Queue() - self.send_queue = queue.Queue() + self.recv_queue = Queue() + self.send_queue = Queue() self.send_queue_lock = threading.Lock() self.recv_queue_lock = threading.Lock() self.is_live = True diff --git a/sleekxmpp/test/mocksocket.py b/sleekxmpp/test/mocksocket.py index 0920b7ea..4c9d1699 100644 --- a/sleekxmpp/test/mocksocket.py +++ b/sleekxmpp/test/mocksocket.py @@ -7,10 +7,8 @@ """ import socket -try: - import queue -except ImportError: - import Queue as queue + +from sleekxmpp.util import Queue class TestSocket(object): @@ -36,8 +34,8 @@ class TestSocket(object): Same as arguments for socket.socket """ self.socket = socket.socket(*args, **kwargs) - self.recv_queue = queue.Queue() - self.send_queue = queue.Queue() + self.recv_queue = Queue() + self.send_queue = Queue() self.is_live = False self.disconnected = False diff --git a/sleekxmpp/test/sleektest.py b/sleekxmpp/test/sleektest.py index cac99f77..47af86cf 100644 --- a/sleekxmpp/test/sleektest.py +++ b/sleekxmpp/test/sleektest.py @@ -8,13 +8,10 @@ import unittest from xml.parsers.expat import ExpatError -try: - import Queue as queue -except: - import queue import sleekxmpp from sleekxmpp import ClientXMPP, ComponentXMPP +from sleekxmpp.util import Queue from sleekxmpp.stanza import Message, Iq, Presence from sleekxmpp.test import TestSocket, TestLiveSocket from sleekxmpp.exceptions import XMPPError, IqTimeout, IqError @@ -338,7 +335,7 @@ class SleekTest(unittest.TestCase): # We will use this to wait for the session_start event # for live connections. - skip_queue = queue.Queue() + skip_queue = Queue() if socket == 'mock': self.xmpp.set_socket(TestSocket()) diff --git a/sleekxmpp/util/__init__.py b/sleekxmpp/util/__init__.py new file mode 100644 index 00000000..86a87222 --- /dev/null +++ b/sleekxmpp/util/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +""" + sleekxmpp.util + ~~~~~~~~~~~~~~ + + Part of SleekXMPP: The Sleek XMPP Library + + :copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout + :license: MIT, see LICENSE for more details +""" + + +# ===================================================================== +# Standardize import of Queue class: + +try: + import queue +except ImportError: + import Queue as queue + + +Queue = queue.Queue +QueueEmpty = queue.Empty diff --git a/sleekxmpp/util/stringprep_profiles.py b/sleekxmpp/util/stringprep_profiles.py new file mode 100644 index 00000000..6844c9ac --- /dev/null +++ b/sleekxmpp/util/stringprep_profiles.py @@ -0,0 +1,119 @@ +from __future__ import unicode_literals + +import sys +import stringprep +import unicodedata + + +class StringPrepError(UnicodeError): + pass + + +def to_unicode(data): + if sys.version_info < (3, 0): + return unicode(data) + else: + return str(data) + + +def b1_mapping(char): + return '' if stringprep.in_table_c12(char) else None + + +def c12_mapping(char): + return ' ' if stringprep.in_table_c12(char) else None + + +def map_input(data, tables=None): + """ + Each character in the input stream MUST be checked against + a mapping table. + """ + result = [] + for char in data: + replacement = None + + for mapping in tables: + replacement = mapping(char) + if replacement is not None: + break + + if replacement is None: + replacement = char + result.append(replacement) + return ''.join(result) + + +def normalize(data, nfkc=True): + """ + A profile can specify one of two options for Unicode normalization: + - no normalization + - Unicode normalization with form KC + """ + if nfkc: + data = unicodedata.normalize('NFKC', data) + return data + + +def prohibit_output(data, tables=None): + """ + Before the text can be emitted, it MUST be checked for prohibited + code points. + """ + for char in data: + for check in tables: + if check(char): + raise StringPrepError("Prohibited code point: %s" % char) + + +def check_bidi(data): + """ + 1) The characters in section 5.8 MUST be prohibited. + + 2) If a string contains any RandALCat character, the string MUST NOT + contain any LCat character. + + 3) If a string contains any RandALCat character, a RandALCat + character MUST be the first character of the string, and a + RandALCat character MUST be the last character of the string. + """ + if not data: + return data + + has_lcat = False + has_randal = False + + for c in data: + if stringprep.in_table_c8(c): + raise StringPrepError("BIDI violation: seciton 6 (1)") + if stringprep.in_table_d1(c): + has_randal = True + elif stringprep.in_table_d2(c): + has_lcat = True + + if has_randal and has_lcat: + raise StringPrepError("BIDI violation: section 6 (2)") + + first_randal = stringprep.in_table_d1(data[0]) + last_randal = stringprep.in_table_d1(data[-1]) + if has_randal and not (first_randal and last_randal): + raise StringPrepError("BIDI violation: section 6 (3)") + + +def create(nfkc=True, bidi=True, mappings=None, + prohibited=None, unassigned=None): + def profile(data, query=False): + try: + data = to_unicode(data) + except UnicodeError: + raise StringPrepError + + data = map_input(data, mappings) + data = normalize(data, nfkc) + prohibit_output(data, prohibited) + if bidi: + check_bidi(data) + if query and unassigned: + check_unassigned(data, unassigned) + return data + return profile diff --git a/sleekxmpp/xmlstream/__init__.py b/sleekxmpp/xmlstream/__init__.py index 67b20c56..5a1ea1be 100644 --- a/sleekxmpp/xmlstream/__init__.py +++ b/sleekxmpp/xmlstream/__init__.py @@ -6,7 +6,7 @@ See the file LICENSE for copying permission. """ -from sleekxmpp.xmlstream.jid import JID +from sleekxmpp.jid import JID from sleekxmpp.xmlstream.scheduler import Scheduler from sleekxmpp.xmlstream.stanzabase import StanzaBase, ElementBase, ET from sleekxmpp.xmlstream.stanzabase import register_stanza_plugin diff --git a/sleekxmpp/xmlstream/handler/waiter.py b/sleekxmpp/xmlstream/handler/waiter.py index 899df17c..66e14496 100644 --- a/sleekxmpp/xmlstream/handler/waiter.py +++ b/sleekxmpp/xmlstream/handler/waiter.py @@ -10,11 +10,8 @@ """ import logging -try: - import queue -except ImportError: - import Queue as queue +from sleekxmpp.util import Queue, QueueEmpty from sleekxmpp.xmlstream.handler.base import BaseHandler @@ -37,7 +34,7 @@ class Waiter(BaseHandler): def __init__(self, name, matcher, stream=None): BaseHandler.__init__(self, name, matcher, stream=stream) - self._payload = queue.Queue() + self._payload = Queue() def prerun(self, payload): """Store the matched stanza when received during processing. @@ -74,7 +71,7 @@ class Waiter(BaseHandler): try: stanza = self._payload.get(True, 1) break - except queue.Empty: + except QueueEmpty: elapsed_time += 1 if elapsed_time >= timeout: log.warning("Timed out waiting for %s", self.name) diff --git a/sleekxmpp/xmlstream/jid.py b/sleekxmpp/xmlstream/jid.py index 1582164a..2b59db47 100644 --- a/sleekxmpp/xmlstream/jid.py +++ b/sleekxmpp/xmlstream/jid.py @@ -1,148 +1,5 @@ -# -*- coding: utf-8 -*- -""" - sleekxmpp.xmlstream.jid - ~~~~~~~~~~~~~~~~~~~~~~~ +import logging - This module allows for working with Jabber IDs (JIDs) by - providing accessors for the various components of a JID. +logging.warning('Deprecated: sleekxmpp.xmlstream.jid is moving to sleekxmpp.jid') - Part of SleekXMPP: The Sleek XMPP Library - - :copyright: (c) 2011 Nathanael C. Fritz - :license: MIT, see LICENSE for more details -""" - -from __future__ import unicode_literals - - -class JID(object): - - """ - A representation of a Jabber ID, or JID. - - Each JID may have three components: a user, a domain, and an optional - resource. For example: user@domain/resource - - When a resource is not used, the JID is called a bare JID. - The JID is a full JID otherwise. - - **JID Properties:** - :jid: Alias for ``full``. - :full: The value of the full JID. - :bare: The value of the bare JID. - :user: The username portion of the JID. - :domain: The domain name portion of the JID. - :server: Alias for ``domain``. - :resource: The resource portion of the JID. - - :param string jid: A string of the form ``'[user@]domain[/resource]'``. - """ - - def __init__(self, jid): - """Initialize a new JID""" - self.reset(jid) - - def reset(self, jid): - """Start fresh from a new JID string. - - :param string jid: A string of the form ``'[user@]domain[/resource]'``. - """ - if isinstance(jid, JID): - jid = jid.full - self._full = self._jid = jid - self._domain = None - self._resource = None - self._user = None - self._bare = None - - def __getattr__(self, name): - """Handle getting the JID values, using cache if available. - - :param name: One of: user, server, domain, resource, - full, or bare. - """ - if name == 'resource': - if self._resource is None and '/' in self._jid: - self._resource = self._jid.split('/', 1)[-1] - return self._resource or "" - elif name == 'user': - if self._user is None: - if '@' in self._jid: - self._user = self._jid.split('@', 1)[0] - else: - self._user = self._user - return self._user or "" - elif name in ('server', 'domain', 'host'): - if self._domain is None: - self._domain = self._jid.split('@', 1)[-1].split('/', 1)[0] - return self._domain or "" - elif name in ('full', 'jid'): - return self._jid or "" - elif name == 'bare': - if self._bare is None: - self._bare = self._jid.split('/', 1)[0] - return self._bare or "" - - def __setattr__(self, name, value): - """Edit a JID by updating it's individual values, resetting the - generated JID in the end. - - Arguments: - name -- The name of the JID part. One of: user, domain, - server, resource, full, jid, or bare. - value -- The new value for the JID part. - """ - if name in ('resource', 'user', 'domain'): - object.__setattr__(self, "_%s" % name, value) - self.regenerate() - elif name in ('server', 'domain', 'host'): - self.domain = value - elif name in ('full', 'jid'): - self.reset(value) - self.regenerate() - elif name == 'bare': - if '@' in value: - u, d = value.split('@', 1) - object.__setattr__(self, "_user", u) - object.__setattr__(self, "_domain", d) - else: - object.__setattr__(self, "_user", '') - object.__setattr__(self, "_domain", value) - self.regenerate() - else: - object.__setattr__(self, name, value) - - def regenerate(self): - """Generate a new JID based on current values, useful after editing.""" - jid = "" - if self.user: - jid = "%s@" % self.user - jid += self.domain - if self.resource: - jid += "/%s" % self.resource - self.reset(jid) - - def __str__(self): - """Use the full JID as the string value.""" - return self.full - - def __repr__(self): - return self.full - - def __eq__(self, other): - """ - Two JIDs are considered equal if they have the same full JID value. - """ - other = JID(other) - return self.full == other.full - - def __ne__(self, other): - """Two JIDs are considered unequal if they are not equal.""" - return not self == other - - def __hash__(self): - """Hash a JID based on the string version of its full JID.""" - return hash(self.full) - - def __copy__(self): - return JID(self.jid) +from sleekxmpp.jid import JID diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py index f68af081..d98dc6c8 100644 --- a/sleekxmpp/xmlstream/scheduler.py +++ b/sleekxmpp/xmlstream/scheduler.py @@ -15,10 +15,8 @@ import time import threading import logging -try: - import queue -except ImportError: - import Queue as queue + +from sleekxmpp.util import Queue, QueueEmpty log = logging.getLogger(__name__) @@ -102,7 +100,7 @@ class Scheduler(object): def __init__(self, parentstop=None): #: A queue for storing tasks - self.addq = queue.Queue() + self.addq = Queue() #: A list of tasks in order of execution time. self.schedule = [] @@ -157,7 +155,7 @@ class Scheduler(object): elapsed < wait: newtask = self.addq.get(True, 0.1) elapsed += 0.1 - except queue.Empty: + except QueueEmpty: cleanup = [] self.schedule_lock.acquire() for task in self.schedule: diff --git a/sleekxmpp/xmlstream/tostring.py b/sleekxmpp/xmlstream/tostring.py index 2480f9b2..f22e7770 100644 --- a/sleekxmpp/xmlstream/tostring.py +++ b/sleekxmpp/xmlstream/tostring.py @@ -63,9 +63,11 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, default_ns = '' stream_ns = '' + use_cdata = False if stream: default_ns = stream.default_ns stream_ns = stream.stream_ns + use_cdata = stream.use_cdata # Output the tag name and derived namespace of the element. namespace = '' @@ -81,7 +83,7 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, # Output escaped attribute values. for attrib, value in xml.attrib.items(): - value = xml_escape(value) + value = escape(value, use_cdata) if '}' not in attrib: output.append(' %s="%s"' % (attrib, value)) else: @@ -105,24 +107,24 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, # If there are additional child elements to serialize. output.append(">") if xml.text: - output.append(xml_escape(xml.text)) + output.append(escape(xml.text, use_cdata)) if len(xml): for child in xml: output.append(tostring(child, tag_xmlns, stanza_ns, stream)) output.append("</%s>" % tag_name) elif xml.text: # If we only have text content. - output.append(">%s</%s>" % (xml_escape(xml.text), tag_name)) + output.append(">%s</%s>" % (escape(xml.text, use_cdata), tag_name)) else: # Empty element. output.append(" />") if xml.tail: # If there is additional text after the element. - output.append(xml_escape(xml.tail)) + output.append(escape(xml.tail, use_cdata)) return ''.join(output) -def xml_escape(text): +def escape(text, use_cdata=False): """Convert special characters in XML to escape sequences. :param string text: The XML text to convert. @@ -132,12 +134,24 @@ def xml_escape(text): if type(text) != types.UnicodeType: text = unicode(text, 'utf-8', 'ignore') - text = list(text) escapes = {'&': '&', '<': '<', '>': '>', "'": ''', '"': '"'} - for i, c in enumerate(text): - text[i] = escapes.get(c, c) - return ''.join(text) + + if not use_cdata: + text = list(text) + for i, c in enumerate(text): + text[i] = escapes.get(c, c) + return ''.join(text) + else: + escape_needed = False + for c in text: + if c in escapes: + escape_needed = True + break + if escape_needed: + escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>")) + return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped) + return text diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 49f33933..a0b6e4c2 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -26,14 +26,11 @@ import time import random import weakref import uuid -try: - import queue -except ImportError: - import Queue as queue from xml.parsers.expat import ExpatError import sleekxmpp +from sleekxmpp.util import Queue, QueueEmpty from sleekxmpp.thirdparty.statemachine import StateMachine from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase @@ -215,6 +212,10 @@ class XMLStream(object): #: If set to ``True``, attempt to use IPv6. self.use_ipv6 = True + #: Use CDATA for escaping instead of XML entities. Defaults + #: to ``False``. + self.use_cdata = False + #: An optional dictionary of proxy settings. It may provide: #: :host: The host offering proxy services. #: :port: The port for the proxy service. @@ -270,10 +271,10 @@ class XMLStream(object): self.end_session_on_disconnect = True #: A queue of stream, custom, and scheduled events to be processed. - self.event_queue = queue.Queue() + self.event_queue = Queue() #: A queue of string data to be sent over the stream. - self.send_queue = queue.Queue() + self.send_queue = Queue() self.send_queue_lock = threading.Lock() self.send_lock = threading.RLock() @@ -1586,7 +1587,7 @@ class XMLStream(object): try: wait = self.wait_timeout event = self.event_queue.get(True, timeout=wait) - except queue.Empty: + except QueueEmpty: event = None if event is None: continue @@ -1655,7 +1656,7 @@ class XMLStream(object): else: try: data = self.send_queue.get(True, 1) - except queue.Empty: + except QueueEmpty: continue log.debug("SEND: %s", data) enc_data = data.encode('utf-8') diff --git a/tests/test_jid.py b/tests/test_jid.py index ef1145d3..aeb635a1 100644 --- a/tests/test_jid.py +++ b/tests/test_jid.py @@ -1,5 +1,5 @@ from sleekxmpp.test import * -from sleekxmpp.xmlstream.jid import JID +from sleekxmpp import JID, InvalidJID class TestJIDClass(SleekTest): @@ -137,5 +137,146 @@ class TestJIDClass(SleekTest): self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal") self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal") + def testZeroLengthDomain(self): + self.assertRaises(InvalidJID, JID, domain='') + self.assertRaises(InvalidJID, JID, 'user@/resource') + + def testZeroLengthLocalPart(self): + self.assertRaises(InvalidJID, JID, local='', domain='test.com') + self.assertRaises(InvalidJID, JID, '@/test.com') + + def testZeroLengthResource(self): + self.assertRaises(InvalidJID, JID, domain='test.com', resource='') + self.assertRaises(InvalidJID, JID, 'test.com/') + + def test1023LengthDomain(self): + domain = ('a.' * 509) + 'a.com' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + def test1023LengthLocalPart(self): + local = 'a' * 1023 + jid1 = JID(local=local, domain='test.com') + jid2 = JID('%s@test.com' % local) + + def test1023LengthResource(self): + resource = 'r' * 1023 + jid1 = JID(domain='test.com', resource=resource) + jid2 = JID('test.com/%s' % resource) + + def test1024LengthDomain(self): + domain = ('a.' * 509) + 'aa.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def test1024LengthLocalPart(self): + local = 'a' * 1024 + self.assertRaises(InvalidJID, JID, local=local, domain='test.com') + self.assertRaises(InvalidJID, JID, '%s@/test.com' % local) + + def test1024LengthResource(self): + resource = 'r' * 1024 + self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource) + self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource) + + def testTooLongDomainLabel(self): + domain = ('a' * 64) + '.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainEmptyLabel(self): + domain = 'aaa..bbb.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainIPv4(self): + domain = '127.0.0.1' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + def testDomainIPv6(self): + domain = '[::1]' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + def testDomainInvalidIPv6NoBrackets(self): + domain = '::1' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain, '[::1]') + self.assertEqual(jid2.domain, '[::1]') + + def testDomainInvalidIPv6MissingBracket(self): + domain = '[::1' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain, '[::1]') + self.assertEqual(jid2.domain, '[::1]') + + def testDomainWithPort(self): + domain = 'example.com:5555' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainWithTrailingDot(self): + domain = 'example.com.' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain, 'example.com') + self.assertEqual(jid2.domain, 'example.com') + + def testDomainWithDashes(self): + domain = 'example.com-' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + domain = '-example.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testACEDomain(self): + domain = 'xn--bcher-kva.ch' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch') + self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch') + + def testJIDEscapeExistingSequences(self): + jid = JID(local='blah\\foo\\20bar', domain='example.com') + self.assertEqual(jid.local, 'blah\\foo\\5c20bar') + + def testJIDEscape(self): + jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")', + domain='example.com') + self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)') + + def testJIDUnescape(self): + jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")', + domain='example.com') + ujid = jid.unescape() + self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")') + + jid = JID(local='blah\\foo\\20bar', domain='example.com') + ujid = jid.unescape() + self.assertEqual(ujid.local, 'blah\\foo\\20bar') + + def testStartOrEndWithEscapedSpaces(self): + local = ' foo' + self.assertRaises(InvalidJID, JID, local=local, domain='example.com') + self.assertRaises(InvalidJID, JID, '%s@example.com' % local) + + local = 'bar ' + self.assertRaises(InvalidJID, JID, local=local, domain='example.com') + self.assertRaises(InvalidJID, JID, '%s@example.com' % local) + + # Need more input for these cases. A JID starting with \20 *is* valid + # according to RFC 6122, but is not according to XEP-0106. + #self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2') + #self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20') + suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass) diff --git a/tests/test_tostring.py b/tests/test_tostring.py index e456d28e..cd50a7c1 100644 --- a/tests/test_tostring.py +++ b/tests/test_tostring.py @@ -1,7 +1,7 @@ from sleekxmpp.test import * from sleekxmpp.stanza import Message from sleekxmpp.xmlstream.stanzabase import ET, ElementBase -from sleekxmpp.xmlstream.tostring import tostring, xml_escape +from sleekxmpp.xmlstream.tostring import tostring, escape class TestToString(SleekTest): @@ -30,7 +30,7 @@ class TestToString(SleekTest): def testXMLEscape(self): """Test escaping XML special characters.""" original = """<foo bar="baz">'Hi & welcome!'</foo>""" - escaped = xml_escape(original) + escaped = escape(original) desired = """<foo bar="baz">'Hi""" desired += """ & welcome!'</foo>""" |