diff options
-rw-r--r-- | INSTALL | 5 | ||||
-rwxr-xr-x | setup.py | 9 | ||||
-rw-r--r-- | slixmpp/basexmpp.py | 6 | ||||
-rw-r--r-- | slixmpp/features/feature_bind/bind.py | 2 | ||||
-rw-r--r-- | slixmpp/jid.py | 469 | ||||
-rw-r--r-- | slixmpp/plugins/xep_0078/legacyauth.py | 5 | ||||
-rw-r--r-- | slixmpp/stringprep.py | 105 | ||||
-rw-r--r-- | slixmpp/stringprep.pyx | 71 | ||||
-rw-r--r-- | tests/test_jid.py | 116 |
9 files changed, 395 insertions, 393 deletions
@@ -1,5 +1,6 @@ Pre-requisites: -- Python 3.1 or 2.6 +- Python 3.4 +- Cython 0.22 and libidn, optionally (making JID faster by compiling the stringprep module) Install: > python3 setup.py install @@ -9,4 +10,4 @@ Root install: To test: > cd examples -> python echo_client.py -v -j [USER@example.com] -p [PASSWORD] +> python3 echo_client.py -d -j [USER@example.com] -p [PASSWORD] @@ -13,6 +13,14 @@ try: except ImportError: from distutils.core import setup +try: + from Cython.Build import cythonize +except ImportError: + print('Cython not found, falling back to the slow stringprep module.') + ext_modules = None +else: + ext_modules = cythonize('slixmpp/stringprep.pyx') + from run_tests import TestCommand from slixmpp.version import __version__ @@ -43,6 +51,7 @@ setup( license='MIT', platforms=['any'], packages=packages, + ext_modules=ext_modules, requires=['aiodns', 'pyasn1', 'pyasn1_modules'], classifiers=CLASSIFIERS, cmdclass={'test': TestCommand} diff --git a/slixmpp/basexmpp.py b/slixmpp/basexmpp.py index f60ba560..80699319 100644 --- a/slixmpp/basexmpp.py +++ b/slixmpp/basexmpp.py @@ -57,12 +57,12 @@ class BaseXMPP(XMLStream): self.stream_id = None #: The JabberID (JID) requested for this connection. - self.requested_jid = JID(jid, cache_lock=True) + self.requested_jid = JID(jid) #: The JabberID (JID) used by this connection, #: as set after session binding. This may even be a #: different bare JID than what was requested. - self.boundjid = JID(jid, cache_lock=True) + self.boundjid = JID(jid) self._expected_server_name = self.boundjid.host self._redirect_attempts = 0 @@ -638,7 +638,7 @@ class BaseXMPP(XMLStream): def set_jid(self, jid): """Rip a JID apart and claim it as our own.""" log.debug("setting jid to %s", jid) - self.boundjid = JID(jid, cache_lock=True) + self.boundjid = JID(jid) def getjidresource(self, fulljid): if '/' in fulljid: diff --git a/slixmpp/features/feature_bind/bind.py b/slixmpp/features/feature_bind/bind.py index 25c99948..c031ab72 100644 --- a/slixmpp/features/feature_bind/bind.py +++ b/slixmpp/features/feature_bind/bind.py @@ -52,7 +52,7 @@ class FeatureBind(BasePlugin): iq.send(callback=self._on_bind_response) def _on_bind_response(self, response): - self.xmpp.boundjid = JID(response['bind']['jid'], cache_lock=True) + self.xmpp.boundjid = JID(response['bind']['jid']) self.xmpp.bound = True self.xmpp.event('session_bind', self.xmpp.boundjid) self.xmpp.session_bind_event.set() diff --git a/slixmpp/jid.py b/slixmpp/jid.py index 1df7d1f0..2e23e242 100644 --- a/slixmpp/jid.py +++ b/slixmpp/jid.py @@ -11,24 +11,15 @@ :license: MIT, see LICENSE for more details """ -from __future__ import unicode_literals - import re import socket -import stringprep -import threading -import encodings.idna from copy import deepcopy +from functools import lru_cache -from slixmpp.util import stringprep_profiles -from collections import OrderedDict +from slixmpp.stringprep import nodeprep, resourceprep, idna, StringprepError -#: These characters are not allowed to appear in a JID. -ILLEGAL_CHARS = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + \ - '\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19' + \ - '\x1a\x1b\x1c\x1d\x1e\x1f' + \ - ' !"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~\x7f' +HAVE_INET_PTON = hasattr(socket, 'inet_pton') #: The basic regex pattern that a JID must match in order to determine #: the local, domain, and resource parts. This regex does NOT do any @@ -38,22 +29,8 @@ JID_PATTERN = re.compile( ) #: The set of escape sequences for the characters not allowed by nodeprep. -JID_ESCAPE_SEQUENCES = set(['\\20', '\\22', '\\26', '\\27', '\\2f', - '\\3a', '\\3c', '\\3e', '\\40', '\\5c']) - -#: A mapping of unallowed characters to their escape sequences. An escape -#: sequence for '\' is also included since it must also be escaped in -#: certain situations. -JID_ESCAPE_TRANSFORMATIONS = {' ': '\\20', - '"': '\\22', - '&': '\\26', - "'": '\\27', - '/': '\\2f', - ':': '\\3a', - '<': '\\3c', - '>': '\\3e', - '@': '\\40', - '\\': '\\5c'} +JID_ESCAPE_SEQUENCES = {'\\20', '\\22', '\\26', '\\27', '\\2f', + '\\3a', '\\3c', '\\3e', '\\40', '\\5c'} #: The reverse mapping of escape sequences to their original forms. JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ', @@ -67,70 +44,9 @@ JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ', '\\40': '@', '\\5c': '\\'} -JID_CACHE = OrderedDict() -JID_CACHE_LOCK = threading.Lock() -JID_CACHE_MAX_SIZE = 1024 - -def _cache(key, parts, locked): - JID_CACHE[key] = (parts, locked) - if len(JID_CACHE) > JID_CACHE_MAX_SIZE: - with JID_CACHE_LOCK: - while len(JID_CACHE) > JID_CACHE_MAX_SIZE: - found = None - for key, item in JID_CACHE.items(): - if not item[1]: # if not locked - found = key - break - if not found: # more than MAX_SIZE locked - # warn? - break - del JID_CACHE[found] - -# pylint: disable=c0103 -#: The nodeprep profile of stringprep used to validate the local, -#: or username, portion of a JID. -nodeprep = stringprep_profiles.create( - nfkc=True, - bidi=True, - mappings=[ - stringprep_profiles.b1_mapping, - stringprep.map_table_b2], - prohibited=[ - stringprep.in_table_c11, - stringprep.in_table_c12, - stringprep.in_table_c21, - stringprep.in_table_c22, - stringprep.in_table_c3, - stringprep.in_table_c4, - stringprep.in_table_c5, - stringprep.in_table_c6, - stringprep.in_table_c7, - stringprep.in_table_c8, - stringprep.in_table_c9, - lambda c: c in ' \'"&/:<>@'], - unassigned=[stringprep.in_table_a1]) - -# pylint: disable=c0103 -#: The resourceprep profile of stringprep, which is used to validate -#: the resource portion of a JID. -resourceprep = stringprep_profiles.create( - nfkc=True, - bidi=True, - mappings=[stringprep_profiles.b1_mapping], - prohibited=[ - stringprep.in_table_c12, - stringprep.in_table_c21, - stringprep.in_table_c22, - stringprep.in_table_c3, - stringprep.in_table_c4, - stringprep.in_table_c5, - stringprep.in_table_c6, - stringprep.in_table_c7, - stringprep.in_table_c8, - stringprep.in_table_c9], - unassigned=[stringprep.in_table_a1]) - +# TODO: Find the best cache size for a standard usage. +@lru_cache(maxsize=1024) def _parse_jid(data): """ Parse string data into the node, domain, and resource @@ -162,17 +78,19 @@ def _validate_node(node): :returns: The local portion of a JID, as validated by nodeprep. """ + if node is None: + return None + try: - if node is not None: - node = nodeprep(node) + node = nodeprep(node) + except StringprepError: + raise InvalidJID('Nodeprep failed') - if not node: - raise InvalidJID('Localpart must not be 0 bytes') - if len(node) > 1023: - raise InvalidJID('Localpart must be less than 1024 bytes') - return node - except stringprep_profiles.StringPrepError: - raise InvalidJID('Invalid local part') + if not node: + raise InvalidJID('Localpart must not be 0 bytes') + if len(node) > 1023: + raise InvalidJID('Localpart must be less than 1024 bytes') + return node def _validate_domain(domain): @@ -199,10 +117,10 @@ def _validate_domain(domain): pass # Check if this is an IPv6 address - if not ip_addr and hasattr(socket, 'inet_pton'): + if not ip_addr and HAVE_INET_PTON and domain[0] == '[' and domain[-1] == ']': try: - socket.inet_pton(socket.AF_INET6, domain.strip('[]')) - domain = '[%s]' % domain.strip('[]') + ip = domain[1:-1] + socket.inet_pton(socket.AF_INET6, ip) ip_addr = True except (socket.error, ValueError): pass @@ -213,31 +131,19 @@ def _validate_domain(domain): if domain and domain[-1] == '.': domain = domain[:-1] - domain_parts = [] - for label in domain.split('.'): - try: - label = encodings.idna.nameprep(label) - encodings.idna.ToASCII(label) - pass_nameprep = True - except UnicodeError: - pass_nameprep = False - - if not pass_nameprep: - raise InvalidJID('Could not encode domain as ASCII') - - if label.startswith('xn--'): - label = encodings.idna.ToUnicode(label) - - for char in label: - if char in ILLEGAL_CHARS: - raise InvalidJID('Domain contains illegal characters') + try: + domain = idna(domain) + except StringprepError: + raise InvalidJID('idna validation failed') + if ':' in domain: + raise InvalidJID('Domain containing a port') + for label in domain.split('.'): + if not label: + raise InvalidJID('Domain containing too many dots') if '-' in (label[0], label[-1]): raise InvalidJID('Domain started or ended with -') - domain_parts.append(label) - domain = '.'.join(domain_parts) - if not domain: raise InvalidJID('Domain must not be 0 bytes') if len(domain) > 1023: @@ -253,42 +159,19 @@ def _validate_resource(resource): :returns: The local portion of a JID, as validated by resourceprep. """ - try: - if resource is not None: - resource = resourceprep(resource) + if resource is None: + return None - if not resource: - raise InvalidJID('Resource must not be 0 bytes') - if len(resource) > 1023: - raise InvalidJID('Resource must be less than 1024 bytes') - return resource - except stringprep_profiles.StringPrepError: - raise InvalidJID('Invalid resource') - - -def _escape_node(node): - """Escape the local portion of a JID.""" - result = [] - - for i, char in enumerate(node): - if char == '\\': - if ''.join((node[i:i+3])) in JID_ESCAPE_SEQUENCES: - result.append('\\5c') - continue - result.append(char) - - for i, char in enumerate(result): - if char != '\\': - result[i] = JID_ESCAPE_TRANSFORMATIONS.get(char, char) - - escaped = ''.join(result) - - if escaped.startswith('\\20') or escaped.endswith('\\20'): - raise InvalidJID('Escaped local part starts or ends with "\\20"') - - _validate_node(escaped) + try: + resource = resourceprep(resource) + except StringprepError: + raise InvalidJID('Resourceprep failed') - return escaped + if not resource: + raise InvalidJID('Resource must not be 0 bytes') + if len(resource) > 1023: + raise InvalidJID('Resource must be less than 1024 bytes') + return resource def _unescape_node(node): @@ -313,9 +196,7 @@ def _unescape_node(node): seq = seq[1:] else: unescaped.append(char) - unescaped = ''.join(unescaped) - - return unescaped + return ''.join(unescaped) def _format_jid(local=None, domain=None, resource=None): @@ -328,12 +209,12 @@ def _format_jid(local=None, domain=None, resource=None): :return: A full or bare JID string. """ result = [] - if local: + if local is not None: result.append(local) result.append('@') - if domain: + if domain is not None: result.append(domain) - if resource: + if resource is not None: result.append('/') result.append(resource) return ''.join(result) @@ -349,47 +230,47 @@ class InvalidJID(ValueError): """ # pylint: disable=R0903 -class UnescapedJID(object): +class UnescapedJID: """ .. versionadded:: 1.1.10 """ - def __init__(self, local, domain, resource): - self._jid = (local, domain, resource) + __slots__ = ('_node', '_domain', '_resource') + + def __init__(self, node, domain, resource): + self._node = node + self._domain = domain + self._resource = resource - # pylint: disable=R0911 - def __getattr__(self, name): + def __getattribute__(self, name): """Retrieve the given JID component. :param name: one of: user, server, domain, resource, full, or bare. """ if name == 'resource': - return self._jid[2] or '' - elif name in ('user', 'username', 'local', 'node'): - return self._jid[0] or '' - elif name in ('server', 'domain', 'host'): - return self._jid[1] or '' - elif name in ('full', 'jid'): - return _format_jid(*self._jid) - elif name == 'bare': - return _format_jid(self._jid[0], self._jid[1]) - elif name == '_jid': - return getattr(super(JID, self), '_jid') - else: - return None + return self._resource or '' + if name in ('user', 'username', 'local', 'node'): + return self._node or '' + if name in ('server', 'domain', 'host'): + return self._domain or '' + if name in ('full', 'jid'): + return _format_jid(self._node, self._domain, self._resource) + if name == 'bare': + return _format_jid(self._node, self._domain) + return object.__getattribute__(self, name) def __str__(self): """Use the full JID as the string value.""" - return _format_jid(*self._jid) + return _format_jid(self._node, self._domain, self._resource) def __repr__(self): """Use the full JID as the representation.""" - return self.__str__() + return _format_jid(self._node, self._domain, self._resource) -class JID(object): +class JID: """ A representation of a Jabber ID, or JID. @@ -401,13 +282,13 @@ class JID(object): The JID is a full JID otherwise. **JID Properties:** - :jid: Alias for ``full``. :full: The string value of the full JID. + :jid: Alias for ``full``. :bare: The string value of the bare JID. - :user: The username portion of the JID. - :username: Alias for ``user``. - :local: Alias for ``user``. - :node: Alias for ``user``. + :node: The node portion of the JID. + :user: Alias for ``node``. + :local: Alias for ``node``. + :username: Alias for ``node``. :domain: The domain name portion of the JID. :server: Alias for ``domain``. :host: Alias for ``domain``. @@ -415,67 +296,23 @@ class JID(object): :param string jid: A string of the form ``'[user@]domain[/resource]'``. - :param string local: - Optional. Specify the local, or username, portion - of the JID. If provided, it will override the local - value provided by the `jid` parameter. The given - local value will also be escaped if necessary. - :param string domain: - Optional. Specify the domain of the JID. If - provided, it will override the domain given by - the `jid` parameter. - :param string resource: - Optional. Specify the resource value of the JID. - If provided, it will override the domain given - by the `jid` parameter. :raises InvalidJID: """ - # pylint: disable=W0212 - def __init__(self, jid=None, **kwargs): - locked = kwargs.get('cache_lock', False) - in_local = kwargs.get('local', None) - in_domain = kwargs.get('domain', None) - in_resource = kwargs.get('resource', None) - parts = None - if in_local or in_domain or in_resource: - parts = (in_local, in_domain, in_resource) - - # only check cache if there is a jid string, or parts, not if there - # are both - self._jid = None - key = None - if (jid is not None) and (parts is None): - if isinstance(jid, JID): - # it's already good to go, and there are no additions - self._jid = jid._jid - return - key = jid - self._jid, locked = JID_CACHE.get(jid, (None, locked)) - elif jid is None and parts is not None: - key = parts - self._jid, locked = JID_CACHE.get(parts, (None, locked)) - if not self._jid: - if not jid: - parsed_jid = (None, None, None) - elif not isinstance(jid, JID): - parsed_jid = _parse_jid(jid) - else: - parsed_jid = jid._jid - - local, domain, resource = parsed_jid - - if 'local' in kwargs: - local = _escape_node(in_local) - if 'domain' in kwargs: - domain = _validate_domain(in_domain) - if 'resource' in kwargs: - resource = _validate_resource(in_resource) - - self._jid = (local, domain, resource) - if key: - _cache(key, self._jid, locked) + __slots__ = ('_node', '_domain', '_resource') + + def __init__(self, jid=None): + if not jid: + self._node = None + self._domain = None + self._resource = None + elif not isinstance(jid, JID): + self._node, self._domain, self._resource = _parse_jid(jid) + else: + self._node = jid._node + self._domain = jid._domain + self._resource = jid._resource def unescape(self): """Return an unescaped JID object. @@ -488,151 +325,125 @@ class JID(object): .. versionadded:: 1.1.10 """ - return UnescapedJID(_unescape_node(self._jid[0]), - self._jid[1], - self._jid[2]) - - def regenerate(self): - """No-op - - .. deprecated:: 1.1.10 - """ - pass - - def reset(self, data): - """Start fresh from a new JID string. - - :param string data: A string of the form ``'[user@]domain[/resource]'``. - - .. deprecated:: 1.1.10 - """ - self._jid = JID(data)._jid + return UnescapedJID(_unescape_node(self._node), + self._domain, + self._resource) @property - def resource(self): - return self._jid[2] or '' + def node(self): + return self._node or '' @property def user(self): - return self._jid[0] or '' + return self._node or '' @property def local(self): - return self._jid[0] or '' - - @property - def node(self): - return self._jid[0] or '' + return self._node or '' @property def username(self): - return self._jid[0] or '' + return self._node or '' @property - def bare(self): - return _format_jid(self._jid[0], self._jid[1]) + def domain(self): + return self._domain or '' @property def server(self): - return self._jid[1] or '' - - @property - def domain(self): - return self._jid[1] or '' + return self._domain or '' @property def host(self): - return self._jid[1] or '' + return self._domain or '' @property - def full(self): - return _format_jid(*self._jid) + def resource(self): + return self._resource or '' @property - def jid(self): - return _format_jid(*self._jid) + def bare(self): + return _format_jid(self._node, self._domain) @property - def bare(self): - return _format_jid(self._jid[0], self._jid[1]) + def full(self): + return _format_jid(self._node, self._domain, self._resource) + @property + def jid(self): + return _format_jid(self._node, self._domain, self._resource) - @resource.setter - def resource(self, value): - self._jid = JID(self, resource=value)._jid + @node.setter + def node(self, value): + self._node = _validate_node(value) @user.setter def user(self, value): - self._jid = JID(self, local=value)._jid - - @username.setter - def username(self, value): - self._jid = JID(self, local=value)._jid + self._node = _validate_node(value) @local.setter def local(self, value): - self._jid = JID(self, local=value)._jid + self._node = _validate_node(value) - @node.setter - def node(self, value): - self._jid = JID(self, local=value)._jid - - @server.setter - def server(self, value): - self._jid = JID(self, domain=value)._jid + @username.setter + def username(self, value): + self._node = _validate_node(value) @domain.setter def domain(self, value): - self._jid = JID(self, domain=value)._jid + self._domain = _validate_domain(value) + + @server.setter + def server(self, value): + self._domain = _validate_domain(value) @host.setter def host(self, value): - self._jid = JID(self, domain=value)._jid + self._domain = _validate_domain(value) + + @bare.setter + def bare(self, value): + node, domain, resource = _parse_jid(value) + assert not resource + self._node = node + self._domain = domain + + @resource.setter + def resource(self, value): + self._resource = _validate_resource(value) @full.setter def full(self, value): - self._jid = JID(value)._jid + self._node, self._domain, self._resource = _parse_jid(value) @jid.setter def jid(self, value): - self._jid = JID(value)._jid - - @bare.setter - def bare(self, value): - parsed = JID(value)._jid - self._jid = (parsed[0], parsed[1], self._jid[2]) - + self._node, self._domain, self._resource = _parse_jid(value) def __str__(self): """Use the full JID as the string value.""" - return _format_jid(*self._jid) + return _format_jid(self._node, self._domain, self._resource) def __repr__(self): """Use the full JID as the representation.""" - return self.__str__() + return _format_jid(self._node, self._domain, self._resource) # pylint: disable=W0212 def __eq__(self, other): """Two JIDs are equal if they have the same full JID value.""" if isinstance(other, UnescapedJID): return False + if not isinstance(other, JID): + other = JID(other) - other = JID(other) - return self._jid == other._jid + return (self._node == other._node and + self._domain == other._domain and + self._resource == other._resource) - # pylint: disable=W0212 def __ne__(self, other): """Two JIDs are considered unequal if they are not equal.""" return not self == other def __hash__(self): """Hash a JID based on the string version of its full JID.""" - return hash(self.__str__()) - - def __copy__(self): - """Generate a duplicate JID.""" - return JID(self) - - def __deepcopy__(self, memo): - """Generate a duplicate JID.""" - return JID(deepcopy(str(self), memo)) + return hash(_format_jid(self._node, self._domain, self._resource)) diff --git a/slixmpp/plugins/xep_0078/legacyauth.py b/slixmpp/plugins/xep_0078/legacyauth.py index 0bcfb3d0..d949a913 100644 --- a/slixmpp/plugins/xep_0078/legacyauth.py +++ b/slixmpp/plugins/xep_0078/legacyauth.py @@ -128,9 +128,8 @@ class XEP_0078(BasePlugin): self.xmpp.authenticated = True - self.xmpp.boundjid = JID(self.xmpp.requested_jid, - resource=resource, - cache_lock=True) + self.xmpp.boundjid = JID(self.xmpp.requested_jid) + self.xmpp.boundjid.resource = resource self.xmpp.event('session_bind', self.xmpp.boundjid) log.debug("Established Session") diff --git a/slixmpp/stringprep.py b/slixmpp/stringprep.py new file mode 100644 index 00000000..e0757ef2 --- /dev/null +++ b/slixmpp/stringprep.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +""" + slixmpp.stringprep + ~~~~~~~~~~~~~~~~~~~~~~~ + + This module is a fallback using python’s stringprep instead of libidn’s. + + Part of Slixmpp: The Slick XMPP Library + + :copyright: (c) 2015 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr> + :license: MIT, see LICENSE for more details +""" + +import logging +import stringprep +from slixmpp.util import stringprep_profiles +import encodings.idna + +class StringprepError(Exception): + pass + +#: These characters are not allowed to appear in a domain part. +ILLEGAL_CHARS = ('\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + '\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19' + '\x1a\x1b\x1c\x1d\x1e\x1f' + ' !"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~\x7f') + + +# pylint: disable=c0103 +#: The nodeprep profile of stringprep used to validate the local, +#: or username, portion of a JID. +_nodeprep = stringprep_profiles.create( + nfkc=True, + bidi=True, + mappings=[ + stringprep_profiles.b1_mapping, + stringprep.map_table_b2], + prohibited=[ + stringprep.in_table_c11, + stringprep.in_table_c12, + stringprep.in_table_c21, + stringprep.in_table_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9, + lambda c: c in ' \'"&/:<>@'], + unassigned=[stringprep.in_table_a1]) + +def nodeprep(node): + try: + return _nodeprep(node) + except stringprep_profiles.StringPrepError: + raise StringprepError + +# pylint: disable=c0103 +#: The resourceprep profile of stringprep, which is used to validate +#: the resource portion of a JID. +_resourceprep = stringprep_profiles.create( + nfkc=True, + bidi=True, + mappings=[stringprep_profiles.b1_mapping], + prohibited=[ + stringprep.in_table_c12, + stringprep.in_table_c21, + stringprep.in_table_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9], + unassigned=[stringprep.in_table_a1]) + +def resourceprep(resource): + try: + return _resourceprep(resource) + except stringprep_profiles.StringPrepError: + raise StringprepError + +def idna(domain): + domain_parts = [] + for label in domain.split('.'): + try: + label = encodings.idna.nameprep(label) + encodings.idna.ToASCII(label) + except UnicodeError: + raise StringprepError + + if label.startswith('xn--'): + label = encodings.idna.ToUnicode(label) + + for char in label: + if char in ILLEGAL_CHARS: + raise StringprepError + + domain_parts.append(label) + return '.'.join(domain_parts) + +logging.getLogger(__name__).warning('Using slower stringprep, consider ' + 'compiling the faster cython/libidn one.') diff --git a/slixmpp/stringprep.pyx b/slixmpp/stringprep.pyx new file mode 100644 index 00000000..e17c62c3 --- /dev/null +++ b/slixmpp/stringprep.pyx @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# cython: language_level = 3 +# distutils: libraries = idn +""" + slixmpp.stringprep + ~~~~~~~~~~~~~~~~~~~~~~~ + + This module wraps libidn’s stringprep and idna functions using Cython. + + Part of Slixmpp: The Slick XMPP Library + + :copyright: (c) 2015 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr> + :license: MIT, see LICENSE for more details +""" + +from libc.stdlib cimport free + + +# Those are Cython declarations for the C function we’ll be using. + +cdef extern from "stringprep.h" nogil: + int stringprep_profile(const char* in_, char** out, const char* profile, int flags) + +cdef extern from "idna.h" nogil: + int idna_to_ascii_8z(const char* in_, char** out, int flags) + int idna_to_unicode_8z8z(const char* in_, char** out, int flags) + + +class StringprepError(Exception): + pass + + +cdef str _stringprep(str in_, const char* profile): + """Python wrapper for libidn’s stringprep.""" + cdef char* out + ret = stringprep_profile(in_.encode('utf-8'), &out, profile, 0) + if ret != 0: + raise StringprepError(ret) + unicode_out = out.decode('utf-8') + free(out) + return unicode_out + +def nodeprep(str node): + """The nodeprep profile of stringprep used to validate the local, or + username, portion of a JID.""" + return _stringprep(node, 'Nodeprep') + +def resourceprep(str resource): + """The resourceprep profile of stringprep, which is used to validate the + resource portion of a JID.""" + return _stringprep(resource, 'Resourceprep') + +def idna(str domain): + """The idna conversion functions, which are used to validate the domain + portion of a JID.""" + + cdef char* ascii_domain + cdef char* utf8_domain + + ret = idna_to_ascii_8z(domain.encode('utf-8'), &ascii_domain, 0) + if ret != 0: + raise StringprepError(ret) + + ret = idna_to_unicode_8z8z(ascii_domain, &utf8_domain, 0) + free(ascii_domain) + if ret != 0: + raise StringprepError(ret) + + unicode_domain = utf8_domain.decode('utf-8') + free(utf8_domain) + return unicode_domain diff --git a/tests/test_jid.py b/tests/test_jid.py index e42305c2..1233eb37 100644 --- a/tests/test_jid.py +++ b/tests/test_jid.py @@ -138,143 +138,149 @@ class TestJIDClass(SlixTest): def testJIDInequality(self): jid1 = JID('user@domain/resource') jid2 = JID('otheruser@domain/resource') - self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal") - self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal") + self.assertFalse(jid1 == jid2, "Different JIDs are considered equal") + self.assertTrue(jid1 != jid2, "Different JIDs are considered equal") def testZeroLengthDomain(self): - self.assertRaises(InvalidJID, JID, domain='') + jid1 = JID('') + jid2 = JID() + self.assertTrue(jid1 == jid2, "Empty JIDs are not considered equal") + self.assertTrue(jid1.domain == '', "Empty JID’s domain part not empty") + self.assertTrue(jid1.full == '', "Empty JID’s full part not empty") + + self.assertRaises(InvalidJID, JID, 'user@') + self.assertRaises(InvalidJID, JID, '/resource') self.assertRaises(InvalidJID, JID, 'user@/resource') def testZeroLengthLocalPart(self): - self.assertRaises(InvalidJID, JID, local='', domain='test.com') + self.assertRaises(InvalidJID, JID, '@test.com') + self.assertRaises(InvalidJID, JID, '@test.com/resource') + + def testZeroLengthNodeDomain(self): self.assertRaises(InvalidJID, JID, '@/test.com') def testZeroLengthResource(self): - self.assertRaises(InvalidJID, JID, domain='test.com', resource='') self.assertRaises(InvalidJID, JID, 'test.com/') + self.assertRaises(InvalidJID, JID, 'user@test.com/') def test1023LengthDomain(self): domain = ('a.' * 509) + 'a.com' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) + jid = JID('user@%s/resource' % domain) def test1023LengthLocalPart(self): local = 'a' * 1023 - jid1 = JID(local=local, domain='test.com') - jid2 = JID('%s@test.com' % local) + jid = JID('%s@test.com' % local) def test1023LengthResource(self): resource = 'r' * 1023 - jid1 = JID(domain='test.com', resource=resource) - jid2 = JID('test.com/%s' % resource) + jid = JID('test.com/%s' % resource) def test1024LengthDomain(self): domain = ('a.' * 509) + 'aa.com' - self.assertRaises(InvalidJID, JID, domain=domain) self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s' % domain) + self.assertRaises(InvalidJID, JID, '%s/resource' % domain) + self.assertRaises(InvalidJID, JID, domain) def test1024LengthLocalPart(self): local = 'a' * 1024 - self.assertRaises(InvalidJID, JID, local=local, domain='test.com') - self.assertRaises(InvalidJID, JID, '%s@/test.com' % local) + self.assertRaises(InvalidJID, JID, '%s@test.com' % local) + self.assertRaises(InvalidJID, JID, '%s@test.com/resource' % local) def test1024LengthResource(self): resource = 'r' * 1024 - self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource) self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource) + self.assertRaises(InvalidJID, JID, 'user@test.com/%s' % resource) def testTooLongDomainLabel(self): domain = ('a' * 64) + '.com' - self.assertRaises(InvalidJID, JID, domain=domain) self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) def testDomainEmptyLabel(self): domain = 'aaa..bbb.com' - self.assertRaises(InvalidJID, JID, domain=domain) self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) def testDomainIPv4(self): domain = '127.0.0.1' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) + + jid1 = JID('%s' % domain) + jid2 = JID('user@%s' % domain) + jid3 = JID('%s/resource' % domain) + jid4 = JID('user@%s/resource' % domain) def testDomainIPv6(self): domain = '[::1]' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) + + jid1 = JID('%s' % domain) + jid2 = JID('user@%s' % domain) + jid3 = JID('%s/resource' % domain) + jid4 = JID('user@%s/resource' % domain) def testDomainInvalidIPv6NoBrackets(self): domain = '::1' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) - self.assertEqual(jid1.domain, '[::1]') - self.assertEqual(jid2.domain, '[::1]') + self.assertRaises(InvalidJID, JID, '%s' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s' % domain) + self.assertRaises(InvalidJID, JID, '%s/resource' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) def testDomainInvalidIPv6MissingBracket(self): domain = '[::1' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) - self.assertEqual(jid1.domain, '[::1]') - self.assertEqual(jid2.domain, '[::1]') + self.assertRaises(InvalidJID, JID, '%s' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s' % domain) + self.assertRaises(InvalidJID, JID, '%s/resource' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainInvalidIPv6WrongBracket(self): + domain = '[::]1]' + + self.assertRaises(InvalidJID, JID, '%s' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s' % domain) + self.assertRaises(InvalidJID, JID, '%s/resource' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) def testDomainWithPort(self): domain = 'example.com:5555' - self.assertRaises(InvalidJID, JID, domain=domain) + + self.assertRaises(InvalidJID, JID, '%s' % domain) + self.assertRaises(InvalidJID, JID, 'user@%s' % domain) + self.assertRaises(InvalidJID, JID, '%s/resource' % domain) self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) def testDomainWithTrailingDot(self): domain = 'example.com.' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) + jid = JID('user@%s/resource' % domain) - self.assertEqual(jid1.domain, 'example.com') - self.assertEqual(jid2.domain, 'example.com') + self.assertEqual(jid.domain, 'example.com') def testDomainWithDashes(self): domain = 'example.com-' - self.assertRaises(InvalidJID, JID, domain=domain) self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) domain = '-example.com' - self.assertRaises(InvalidJID, JID, domain=domain) self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) def testACEDomain(self): domain = 'xn--bcher-kva.ch' - jid1 = JID(domain=domain) - jid2 = JID('user@%s/resource' % domain) - - self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch') - self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch') - - def testJIDEscapeExistingSequences(self): - jid = JID(local='blah\\foo\\20bar', domain='example.com') - self.assertEqual(jid.local, 'blah\\foo\\5c20bar') + jid = JID('user@%s/resource' % domain) - def testJIDEscape(self): - jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")', - domain='example.com') - self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)') + self.assertEqual(jid.domain.encode('utf-8'), b'b\xc3\xbccher.ch') def testJIDUnescape(self): - jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")', - domain='example.com') + jid = JID('here\\27s_a_wild_\\26_\\2fcr%zy\\2f_\\40ddress\\20for\\3a\\3cwv\\3e(\\22IMPS\\22)\\5c@example.com') ujid = jid.unescape() - self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")') + self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_@ddress for:<wv>("imps")\\') - jid = JID(local='blah\\foo\\20bar', domain='example.com') + jid = JID('blah\\5cfoo\\5c20bar@example.com') ujid = jid.unescape() self.assertEqual(ujid.local, 'blah\\foo\\20bar') def testStartOrEndWithEscapedSpaces(self): local = ' foo' - self.assertRaises(InvalidJID, JID, local=local, domain='example.com') self.assertRaises(InvalidJID, JID, '%s@example.com' % local) local = 'bar ' - self.assertRaises(InvalidJID, JID, local=local, domain='example.com') self.assertRaises(InvalidJID, JID, '%s@example.com' % local) # Need more input for these cases. A JID starting with \20 *is* valid |