diff options
Diffstat (limited to 'slixmpp')
-rw-r--r-- | slixmpp/basexmpp.py | 6 | ||||
-rw-r--r-- | slixmpp/features/feature_bind/bind.py | 2 | ||||
-rw-r--r-- | slixmpp/jid.py | 469 | ||||
-rw-r--r-- | slixmpp/plugins/xep_0078/legacyauth.py | 5 | ||||
-rw-r--r-- | slixmpp/stringprep.py | 105 | ||||
-rw-r--r-- | slixmpp/stringprep.pyx | 71 |
6 files changed, 322 insertions, 336 deletions
diff --git a/slixmpp/basexmpp.py b/slixmpp/basexmpp.py index f60ba560..80699319 100644 --- a/slixmpp/basexmpp.py +++ b/slixmpp/basexmpp.py @@ -57,12 +57,12 @@ class BaseXMPP(XMLStream): self.stream_id = None #: The JabberID (JID) requested for this connection. - self.requested_jid = JID(jid, cache_lock=True) + self.requested_jid = JID(jid) #: The JabberID (JID) used by this connection, #: as set after session binding. This may even be a #: different bare JID than what was requested. - self.boundjid = JID(jid, cache_lock=True) + self.boundjid = JID(jid) self._expected_server_name = self.boundjid.host self._redirect_attempts = 0 @@ -638,7 +638,7 @@ class BaseXMPP(XMLStream): def set_jid(self, jid): """Rip a JID apart and claim it as our own.""" log.debug("setting jid to %s", jid) - self.boundjid = JID(jid, cache_lock=True) + self.boundjid = JID(jid) def getjidresource(self, fulljid): if '/' in fulljid: diff --git a/slixmpp/features/feature_bind/bind.py b/slixmpp/features/feature_bind/bind.py index 25c99948..c031ab72 100644 --- a/slixmpp/features/feature_bind/bind.py +++ b/slixmpp/features/feature_bind/bind.py @@ -52,7 +52,7 @@ class FeatureBind(BasePlugin): iq.send(callback=self._on_bind_response) def _on_bind_response(self, response): - self.xmpp.boundjid = JID(response['bind']['jid'], cache_lock=True) + self.xmpp.boundjid = JID(response['bind']['jid']) self.xmpp.bound = True self.xmpp.event('session_bind', self.xmpp.boundjid) self.xmpp.session_bind_event.set() diff --git a/slixmpp/jid.py b/slixmpp/jid.py index 1df7d1f0..2e23e242 100644 --- a/slixmpp/jid.py +++ b/slixmpp/jid.py @@ -11,24 +11,15 @@ :license: MIT, see LICENSE for more details """ -from __future__ import unicode_literals - import re import socket -import stringprep -import threading -import encodings.idna from copy import deepcopy +from functools import lru_cache -from slixmpp.util import stringprep_profiles -from collections import OrderedDict +from slixmpp.stringprep import nodeprep, resourceprep, idna, StringprepError -#: These characters are not allowed to appear in a JID. -ILLEGAL_CHARS = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + \ - '\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19' + \ - '\x1a\x1b\x1c\x1d\x1e\x1f' + \ - ' !"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~\x7f' +HAVE_INET_PTON = hasattr(socket, 'inet_pton') #: The basic regex pattern that a JID must match in order to determine #: the local, domain, and resource parts. This regex does NOT do any @@ -38,22 +29,8 @@ JID_PATTERN = re.compile( ) #: The set of escape sequences for the characters not allowed by nodeprep. -JID_ESCAPE_SEQUENCES = set(['\\20', '\\22', '\\26', '\\27', '\\2f', - '\\3a', '\\3c', '\\3e', '\\40', '\\5c']) - -#: A mapping of unallowed characters to their escape sequences. An escape -#: sequence for '\' is also included since it must also be escaped in -#: certain situations. -JID_ESCAPE_TRANSFORMATIONS = {' ': '\\20', - '"': '\\22', - '&': '\\26', - "'": '\\27', - '/': '\\2f', - ':': '\\3a', - '<': '\\3c', - '>': '\\3e', - '@': '\\40', - '\\': '\\5c'} +JID_ESCAPE_SEQUENCES = {'\\20', '\\22', '\\26', '\\27', '\\2f', + '\\3a', '\\3c', '\\3e', '\\40', '\\5c'} #: The reverse mapping of escape sequences to their original forms. JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ', @@ -67,70 +44,9 @@ JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ', '\\40': '@', '\\5c': '\\'} -JID_CACHE = OrderedDict() -JID_CACHE_LOCK = threading.Lock() -JID_CACHE_MAX_SIZE = 1024 - -def _cache(key, parts, locked): - JID_CACHE[key] = (parts, locked) - if len(JID_CACHE) > JID_CACHE_MAX_SIZE: - with JID_CACHE_LOCK: - while len(JID_CACHE) > JID_CACHE_MAX_SIZE: - found = None - for key, item in JID_CACHE.items(): - if not item[1]: # if not locked - found = key - break - if not found: # more than MAX_SIZE locked - # warn? - break - del JID_CACHE[found] - -# pylint: disable=c0103 -#: The nodeprep profile of stringprep used to validate the local, -#: or username, portion of a JID. -nodeprep = stringprep_profiles.create( - nfkc=True, - bidi=True, - mappings=[ - stringprep_profiles.b1_mapping, - stringprep.map_table_b2], - prohibited=[ - stringprep.in_table_c11, - stringprep.in_table_c12, - stringprep.in_table_c21, - stringprep.in_table_c22, - stringprep.in_table_c3, - stringprep.in_table_c4, - stringprep.in_table_c5, - stringprep.in_table_c6, - stringprep.in_table_c7, - stringprep.in_table_c8, - stringprep.in_table_c9, - lambda c: c in ' \'"&/:<>@'], - unassigned=[stringprep.in_table_a1]) - -# pylint: disable=c0103 -#: The resourceprep profile of stringprep, which is used to validate -#: the resource portion of a JID. -resourceprep = stringprep_profiles.create( - nfkc=True, - bidi=True, - mappings=[stringprep_profiles.b1_mapping], - prohibited=[ - stringprep.in_table_c12, - stringprep.in_table_c21, - stringprep.in_table_c22, - stringprep.in_table_c3, - stringprep.in_table_c4, - stringprep.in_table_c5, - stringprep.in_table_c6, - stringprep.in_table_c7, - stringprep.in_table_c8, - stringprep.in_table_c9], - unassigned=[stringprep.in_table_a1]) - +# TODO: Find the best cache size for a standard usage. +@lru_cache(maxsize=1024) def _parse_jid(data): """ Parse string data into the node, domain, and resource @@ -162,17 +78,19 @@ def _validate_node(node): :returns: The local portion of a JID, as validated by nodeprep. """ + if node is None: + return None + try: - if node is not None: - node = nodeprep(node) + node = nodeprep(node) + except StringprepError: + raise InvalidJID('Nodeprep failed') - if not node: - raise InvalidJID('Localpart must not be 0 bytes') - if len(node) > 1023: - raise InvalidJID('Localpart must be less than 1024 bytes') - return node - except stringprep_profiles.StringPrepError: - raise InvalidJID('Invalid local part') + if not node: + raise InvalidJID('Localpart must not be 0 bytes') + if len(node) > 1023: + raise InvalidJID('Localpart must be less than 1024 bytes') + return node def _validate_domain(domain): @@ -199,10 +117,10 @@ def _validate_domain(domain): pass # Check if this is an IPv6 address - if not ip_addr and hasattr(socket, 'inet_pton'): + if not ip_addr and HAVE_INET_PTON and domain[0] == '[' and domain[-1] == ']': try: - socket.inet_pton(socket.AF_INET6, domain.strip('[]')) - domain = '[%s]' % domain.strip('[]') + ip = domain[1:-1] + socket.inet_pton(socket.AF_INET6, ip) ip_addr = True except (socket.error, ValueError): pass @@ -213,31 +131,19 @@ def _validate_domain(domain): if domain and domain[-1] == '.': domain = domain[:-1] - domain_parts = [] - for label in domain.split('.'): - try: - label = encodings.idna.nameprep(label) - encodings.idna.ToASCII(label) - pass_nameprep = True - except UnicodeError: - pass_nameprep = False - - if not pass_nameprep: - raise InvalidJID('Could not encode domain as ASCII') - - if label.startswith('xn--'): - label = encodings.idna.ToUnicode(label) - - for char in label: - if char in ILLEGAL_CHARS: - raise InvalidJID('Domain contains illegal characters') + try: + domain = idna(domain) + except StringprepError: + raise InvalidJID('idna validation failed') + if ':' in domain: + raise InvalidJID('Domain containing a port') + for label in domain.split('.'): + if not label: + raise InvalidJID('Domain containing too many dots') if '-' in (label[0], label[-1]): raise InvalidJID('Domain started or ended with -') - domain_parts.append(label) - domain = '.'.join(domain_parts) - if not domain: raise InvalidJID('Domain must not be 0 bytes') if len(domain) > 1023: @@ -253,42 +159,19 @@ def _validate_resource(resource): :returns: The local portion of a JID, as validated by resourceprep. """ - try: - if resource is not None: - resource = resourceprep(resource) + if resource is None: + return None - if not resource: - raise InvalidJID('Resource must not be 0 bytes') - if len(resource) > 1023: - raise InvalidJID('Resource must be less than 1024 bytes') - return resource - except stringprep_profiles.StringPrepError: - raise InvalidJID('Invalid resource') - - -def _escape_node(node): - """Escape the local portion of a JID.""" - result = [] - - for i, char in enumerate(node): - if char == '\\': - if ''.join((node[i:i+3])) in JID_ESCAPE_SEQUENCES: - result.append('\\5c') - continue - result.append(char) - - for i, char in enumerate(result): - if char != '\\': - result[i] = JID_ESCAPE_TRANSFORMATIONS.get(char, char) - - escaped = ''.join(result) - - if escaped.startswith('\\20') or escaped.endswith('\\20'): - raise InvalidJID('Escaped local part starts or ends with "\\20"') - - _validate_node(escaped) + try: + resource = resourceprep(resource) + except StringprepError: + raise InvalidJID('Resourceprep failed') - return escaped + if not resource: + raise InvalidJID('Resource must not be 0 bytes') + if len(resource) > 1023: + raise InvalidJID('Resource must be less than 1024 bytes') + return resource def _unescape_node(node): @@ -313,9 +196,7 @@ def _unescape_node(node): seq = seq[1:] else: unescaped.append(char) - unescaped = ''.join(unescaped) - - return unescaped + return ''.join(unescaped) def _format_jid(local=None, domain=None, resource=None): @@ -328,12 +209,12 @@ def _format_jid(local=None, domain=None, resource=None): :return: A full or bare JID string. """ result = [] - if local: + if local is not None: result.append(local) result.append('@') - if domain: + if domain is not None: result.append(domain) - if resource: + if resource is not None: result.append('/') result.append(resource) return ''.join(result) @@ -349,47 +230,47 @@ class InvalidJID(ValueError): """ # pylint: disable=R0903 -class UnescapedJID(object): +class UnescapedJID: """ .. versionadded:: 1.1.10 """ - def __init__(self, local, domain, resource): - self._jid = (local, domain, resource) + __slots__ = ('_node', '_domain', '_resource') + + def __init__(self, node, domain, resource): + self._node = node + self._domain = domain + self._resource = resource - # pylint: disable=R0911 - def __getattr__(self, name): + def __getattribute__(self, name): """Retrieve the given JID component. :param name: one of: user, server, domain, resource, full, or bare. """ if name == 'resource': - return self._jid[2] or '' - elif name in ('user', 'username', 'local', 'node'): - return self._jid[0] or '' - elif name in ('server', 'domain', 'host'): - return self._jid[1] or '' - elif name in ('full', 'jid'): - return _format_jid(*self._jid) - elif name == 'bare': - return _format_jid(self._jid[0], self._jid[1]) - elif name == '_jid': - return getattr(super(JID, self), '_jid') - else: - return None + return self._resource or '' + if name in ('user', 'username', 'local', 'node'): + return self._node or '' + if name in ('server', 'domain', 'host'): + return self._domain or '' + if name in ('full', 'jid'): + return _format_jid(self._node, self._domain, self._resource) + if name == 'bare': + return _format_jid(self._node, self._domain) + return object.__getattribute__(self, name) def __str__(self): """Use the full JID as the string value.""" - return _format_jid(*self._jid) + return _format_jid(self._node, self._domain, self._resource) def __repr__(self): """Use the full JID as the representation.""" - return self.__str__() + return _format_jid(self._node, self._domain, self._resource) -class JID(object): +class JID: """ A representation of a Jabber ID, or JID. @@ -401,13 +282,13 @@ class JID(object): The JID is a full JID otherwise. **JID Properties:** - :jid: Alias for ``full``. :full: The string value of the full JID. + :jid: Alias for ``full``. :bare: The string value of the bare JID. - :user: The username portion of the JID. - :username: Alias for ``user``. - :local: Alias for ``user``. - :node: Alias for ``user``. + :node: The node portion of the JID. + :user: Alias for ``node``. + :local: Alias for ``node``. + :username: Alias for ``node``. :domain: The domain name portion of the JID. :server: Alias for ``domain``. :host: Alias for ``domain``. @@ -415,67 +296,23 @@ class JID(object): :param string jid: A string of the form ``'[user@]domain[/resource]'``. - :param string local: - Optional. Specify the local, or username, portion - of the JID. If provided, it will override the local - value provided by the `jid` parameter. The given - local value will also be escaped if necessary. - :param string domain: - Optional. Specify the domain of the JID. If - provided, it will override the domain given by - the `jid` parameter. - :param string resource: - Optional. Specify the resource value of the JID. - If provided, it will override the domain given - by the `jid` parameter. :raises InvalidJID: """ - # pylint: disable=W0212 - def __init__(self, jid=None, **kwargs): - locked = kwargs.get('cache_lock', False) - in_local = kwargs.get('local', None) - in_domain = kwargs.get('domain', None) - in_resource = kwargs.get('resource', None) - parts = None - if in_local or in_domain or in_resource: - parts = (in_local, in_domain, in_resource) - - # only check cache if there is a jid string, or parts, not if there - # are both - self._jid = None - key = None - if (jid is not None) and (parts is None): - if isinstance(jid, JID): - # it's already good to go, and there are no additions - self._jid = jid._jid - return - key = jid - self._jid, locked = JID_CACHE.get(jid, (None, locked)) - elif jid is None and parts is not None: - key = parts - self._jid, locked = JID_CACHE.get(parts, (None, locked)) - if not self._jid: - if not jid: - parsed_jid = (None, None, None) - elif not isinstance(jid, JID): - parsed_jid = _parse_jid(jid) - else: - parsed_jid = jid._jid - - local, domain, resource = parsed_jid - - if 'local' in kwargs: - local = _escape_node(in_local) - if 'domain' in kwargs: - domain = _validate_domain(in_domain) - if 'resource' in kwargs: - resource = _validate_resource(in_resource) - - self._jid = (local, domain, resource) - if key: - _cache(key, self._jid, locked) + __slots__ = ('_node', '_domain', '_resource') + + def __init__(self, jid=None): + if not jid: + self._node = None + self._domain = None + self._resource = None + elif not isinstance(jid, JID): + self._node, self._domain, self._resource = _parse_jid(jid) + else: + self._node = jid._node + self._domain = jid._domain + self._resource = jid._resource def unescape(self): """Return an unescaped JID object. @@ -488,151 +325,125 @@ class JID(object): .. versionadded:: 1.1.10 """ - return UnescapedJID(_unescape_node(self._jid[0]), - self._jid[1], - self._jid[2]) - - def regenerate(self): - """No-op - - .. deprecated:: 1.1.10 - """ - pass - - def reset(self, data): - """Start fresh from a new JID string. - - :param string data: A string of the form ``'[user@]domain[/resource]'``. - - .. deprecated:: 1.1.10 - """ - self._jid = JID(data)._jid + return UnescapedJID(_unescape_node(self._node), + self._domain, + self._resource) @property - def resource(self): - return self._jid[2] or '' + def node(self): + return self._node or '' @property def user(self): - return self._jid[0] or '' + return self._node or '' @property def local(self): - return self._jid[0] or '' - - @property - def node(self): - return self._jid[0] or '' + return self._node or '' @property def username(self): - return self._jid[0] or '' + return self._node or '' @property - def bare(self): - return _format_jid(self._jid[0], self._jid[1]) + def domain(self): + return self._domain or '' @property def server(self): - return self._jid[1] or '' - - @property - def domain(self): - return self._jid[1] or '' + return self._domain or '' @property def host(self): - return self._jid[1] or '' + return self._domain or '' @property - def full(self): - return _format_jid(*self._jid) + def resource(self): + return self._resource or '' @property - def jid(self): - return _format_jid(*self._jid) + def bare(self): + return _format_jid(self._node, self._domain) @property - def bare(self): - return _format_jid(self._jid[0], self._jid[1]) + def full(self): + return _format_jid(self._node, self._domain, self._resource) + @property + def jid(self): + return _format_jid(self._node, self._domain, self._resource) - @resource.setter - def resource(self, value): - self._jid = JID(self, resource=value)._jid + @node.setter + def node(self, value): + self._node = _validate_node(value) @user.setter def user(self, value): - self._jid = JID(self, local=value)._jid - - @username.setter - def username(self, value): - self._jid = JID(self, local=value)._jid + self._node = _validate_node(value) @local.setter def local(self, value): - self._jid = JID(self, local=value)._jid + self._node = _validate_node(value) - @node.setter - def node(self, value): - self._jid = JID(self, local=value)._jid - - @server.setter - def server(self, value): - self._jid = JID(self, domain=value)._jid + @username.setter + def username(self, value): + self._node = _validate_node(value) @domain.setter def domain(self, value): - self._jid = JID(self, domain=value)._jid + self._domain = _validate_domain(value) + + @server.setter + def server(self, value): + self._domain = _validate_domain(value) @host.setter def host(self, value): - self._jid = JID(self, domain=value)._jid + self._domain = _validate_domain(value) + + @bare.setter + def bare(self, value): + node, domain, resource = _parse_jid(value) + assert not resource + self._node = node + self._domain = domain + + @resource.setter + def resource(self, value): + self._resource = _validate_resource(value) @full.setter def full(self, value): - self._jid = JID(value)._jid + self._node, self._domain, self._resource = _parse_jid(value) @jid.setter def jid(self, value): - self._jid = JID(value)._jid - - @bare.setter - def bare(self, value): - parsed = JID(value)._jid - self._jid = (parsed[0], parsed[1], self._jid[2]) - + self._node, self._domain, self._resource = _parse_jid(value) def __str__(self): """Use the full JID as the string value.""" - return _format_jid(*self._jid) + return _format_jid(self._node, self._domain, self._resource) def __repr__(self): """Use the full JID as the representation.""" - return self.__str__() + return _format_jid(self._node, self._domain, self._resource) # pylint: disable=W0212 def __eq__(self, other): """Two JIDs are equal if they have the same full JID value.""" if isinstance(other, UnescapedJID): return False + if not isinstance(other, JID): + other = JID(other) - other = JID(other) - return self._jid == other._jid + return (self._node == other._node and + self._domain == other._domain and + self._resource == other._resource) - # pylint: disable=W0212 def __ne__(self, other): """Two JIDs are considered unequal if they are not equal.""" return not self == other def __hash__(self): """Hash a JID based on the string version of its full JID.""" - return hash(self.__str__()) - - def __copy__(self): - """Generate a duplicate JID.""" - return JID(self) - - def __deepcopy__(self, memo): - """Generate a duplicate JID.""" - return JID(deepcopy(str(self), memo)) + return hash(_format_jid(self._node, self._domain, self._resource)) diff --git a/slixmpp/plugins/xep_0078/legacyauth.py b/slixmpp/plugins/xep_0078/legacyauth.py index 0bcfb3d0..d949a913 100644 --- a/slixmpp/plugins/xep_0078/legacyauth.py +++ b/slixmpp/plugins/xep_0078/legacyauth.py @@ -128,9 +128,8 @@ class XEP_0078(BasePlugin): self.xmpp.authenticated = True - self.xmpp.boundjid = JID(self.xmpp.requested_jid, - resource=resource, - cache_lock=True) + self.xmpp.boundjid = JID(self.xmpp.requested_jid) + self.xmpp.boundjid.resource = resource self.xmpp.event('session_bind', self.xmpp.boundjid) log.debug("Established Session") diff --git a/slixmpp/stringprep.py b/slixmpp/stringprep.py new file mode 100644 index 00000000..e0757ef2 --- /dev/null +++ b/slixmpp/stringprep.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +""" + slixmpp.stringprep + ~~~~~~~~~~~~~~~~~~~~~~~ + + This module is a fallback using python’s stringprep instead of libidn’s. + + Part of Slixmpp: The Slick XMPP Library + + :copyright: (c) 2015 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr> + :license: MIT, see LICENSE for more details +""" + +import logging +import stringprep +from slixmpp.util import stringprep_profiles +import encodings.idna + +class StringprepError(Exception): + pass + +#: These characters are not allowed to appear in a domain part. +ILLEGAL_CHARS = ('\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + '\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19' + '\x1a\x1b\x1c\x1d\x1e\x1f' + ' !"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~\x7f') + + +# pylint: disable=c0103 +#: The nodeprep profile of stringprep used to validate the local, +#: or username, portion of a JID. +_nodeprep = stringprep_profiles.create( + nfkc=True, + bidi=True, + mappings=[ + stringprep_profiles.b1_mapping, + stringprep.map_table_b2], + prohibited=[ + stringprep.in_table_c11, + stringprep.in_table_c12, + stringprep.in_table_c21, + stringprep.in_table_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9, + lambda c: c in ' \'"&/:<>@'], + unassigned=[stringprep.in_table_a1]) + +def nodeprep(node): + try: + return _nodeprep(node) + except stringprep_profiles.StringPrepError: + raise StringprepError + +# pylint: disable=c0103 +#: The resourceprep profile of stringprep, which is used to validate +#: the resource portion of a JID. +_resourceprep = stringprep_profiles.create( + nfkc=True, + bidi=True, + mappings=[stringprep_profiles.b1_mapping], + prohibited=[ + stringprep.in_table_c12, + stringprep.in_table_c21, + stringprep.in_table_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9], + unassigned=[stringprep.in_table_a1]) + +def resourceprep(resource): + try: + return _resourceprep(resource) + except stringprep_profiles.StringPrepError: + raise StringprepError + +def idna(domain): + domain_parts = [] + for label in domain.split('.'): + try: + label = encodings.idna.nameprep(label) + encodings.idna.ToASCII(label) + except UnicodeError: + raise StringprepError + + if label.startswith('xn--'): + label = encodings.idna.ToUnicode(label) + + for char in label: + if char in ILLEGAL_CHARS: + raise StringprepError + + domain_parts.append(label) + return '.'.join(domain_parts) + +logging.getLogger(__name__).warning('Using slower stringprep, consider ' + 'compiling the faster cython/libidn one.') diff --git a/slixmpp/stringprep.pyx b/slixmpp/stringprep.pyx new file mode 100644 index 00000000..e17c62c3 --- /dev/null +++ b/slixmpp/stringprep.pyx @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# cython: language_level = 3 +# distutils: libraries = idn +""" + slixmpp.stringprep + ~~~~~~~~~~~~~~~~~~~~~~~ + + This module wraps libidn’s stringprep and idna functions using Cython. + + Part of Slixmpp: The Slick XMPP Library + + :copyright: (c) 2015 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr> + :license: MIT, see LICENSE for more details +""" + +from libc.stdlib cimport free + + +# Those are Cython declarations for the C function we’ll be using. + +cdef extern from "stringprep.h" nogil: + int stringprep_profile(const char* in_, char** out, const char* profile, int flags) + +cdef extern from "idna.h" nogil: + int idna_to_ascii_8z(const char* in_, char** out, int flags) + int idna_to_unicode_8z8z(const char* in_, char** out, int flags) + + +class StringprepError(Exception): + pass + + +cdef str _stringprep(str in_, const char* profile): + """Python wrapper for libidn’s stringprep.""" + cdef char* out + ret = stringprep_profile(in_.encode('utf-8'), &out, profile, 0) + if ret != 0: + raise StringprepError(ret) + unicode_out = out.decode('utf-8') + free(out) + return unicode_out + +def nodeprep(str node): + """The nodeprep profile of stringprep used to validate the local, or + username, portion of a JID.""" + return _stringprep(node, 'Nodeprep') + +def resourceprep(str resource): + """The resourceprep profile of stringprep, which is used to validate the + resource portion of a JID.""" + return _stringprep(resource, 'Resourceprep') + +def idna(str domain): + """The idna conversion functions, which are used to validate the domain + portion of a JID.""" + + cdef char* ascii_domain + cdef char* utf8_domain + + ret = idna_to_ascii_8z(domain.encode('utf-8'), &ascii_domain, 0) + if ret != 0: + raise StringprepError(ret) + + ret = idna_to_unicode_8z8z(ascii_domain, &utf8_domain, 0) + free(ascii_domain) + if ret != 0: + raise StringprepError(ret) + + unicode_domain = utf8_domain.decode('utf-8') + free(utf8_domain) + return unicode_domain |