summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sleekxmpp/__init__.py2
-rw-r--r--sleekxmpp/jid.py20
-rw-r--r--sleekxmpp/util/stringprep_profiles.py3
-rw-r--r--tests/test_jid.py143
4 files changed, 160 insertions, 8 deletions
diff --git a/sleekxmpp/__init__.py b/sleekxmpp/__init__.py
index 84b1114f..f0dc2ce2 100644
--- a/sleekxmpp/__init__.py
+++ b/sleekxmpp/__init__.py
@@ -10,7 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP
from sleekxmpp.clientxmpp import ClientXMPP
from sleekxmpp.componentxmpp import ComponentXMPP
from sleekxmpp.stanza import Message, Presence, Iq
-from sleekxmpp.jid import JID
+from sleekxmpp.jid import JID, InvalidJID
from sleekxmpp.xmlstream.handler import *
from sleekxmpp.xmlstream import XMLStream, RestartStream
from sleekxmpp.xmlstream.matcher import *
diff --git a/sleekxmpp/jid.py b/sleekxmpp/jid.py
index f0b7423b..9e9c0d0b 100644
--- a/sleekxmpp/jid.py
+++ b/sleekxmpp/jid.py
@@ -140,13 +140,12 @@ def _validate_node(node):
"""
try:
if node is not None:
- if not node:
- raise InvalidJID('Localpart must not be 0 bytes')
-
node = nodeprep(node)
if not node:
raise InvalidJID('Localpart must not be 0 bytes')
+ if len(node) > 1023:
+ raise InvalidJID('Localpart must be less than 1024 bytes')
return node
except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid local part')
@@ -179,6 +178,7 @@ def _validate_domain(domain):
if not ip_addr and hasattr(socket, 'inet_pton'):
try:
socket.inet_pton(socket.AF_INET6, domain.strip('[]'))
+ domain = '[%s]' % domain.strip('[]')
ip_addr = True
except socket.error:
pass
@@ -186,12 +186,19 @@ def _validate_domain(domain):
if not ip_addr:
# This is a domain name, which must be checked further
+ if domain and domain[-1] == '.':
+ domain = domain[:-1]
+
domain_parts = []
for label in domain.split('.'):
try:
label = encodings.idna.nameprep(label)
encodings.idna.ToASCII(label)
+ pass_nameprep = True
except UnicodeError:
+ pass_nameprep = False
+
+ if not pass_nameprep:
raise InvalidJID('Could not encode domain as ASCII')
if label.startswith('xn--'):
@@ -209,6 +216,8 @@ def _validate_domain(domain):
if not domain:
raise InvalidJID('Domain must not be 0 bytes')
+ if len(domain) > 1023:
+ raise InvalidJID('Domain must be less than 1024 bytes')
return domain
@@ -222,13 +231,12 @@ def _validate_resource(resource):
"""
try:
if resource is not None:
- if not resource:
- raise InvalidJID('Resource must not be 0 bytes')
-
resource = resourceprep(resource)
if not resource:
raise InvalidJID('Resource must not be 0 bytes')
+ if len(resource) > 1023:
+ raise InvalidJID('Resource must be less than 1024 bytes')
return resource
except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid resource')
diff --git a/sleekxmpp/util/stringprep_profiles.py b/sleekxmpp/util/stringprep_profiles.py
index a75bb9dd..6844c9ac 100644
--- a/sleekxmpp/util/stringprep_profiles.py
+++ b/sleekxmpp/util/stringprep_profiles.py
@@ -77,6 +77,9 @@ def check_bidi(data):
character MUST be the first character of the string, and a
RandALCat character MUST be the last character of the string.
"""
+ if not data:
+ return data
+
has_lcat = False
has_randal = False
diff --git a/tests/test_jid.py b/tests/test_jid.py
index 7b800520..aeb635a1 100644
--- a/tests/test_jid.py
+++ b/tests/test_jid.py
@@ -1,5 +1,5 @@
from sleekxmpp.test import *
-from sleekxmpp import JID
+from sleekxmpp import JID, InvalidJID
class TestJIDClass(SleekTest):
@@ -137,5 +137,146 @@ class TestJIDClass(SleekTest):
self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal")
self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal")
+ def testZeroLengthDomain(self):
+ self.assertRaises(InvalidJID, JID, domain='')
+ self.assertRaises(InvalidJID, JID, 'user@/resource')
+
+ def testZeroLengthLocalPart(self):
+ self.assertRaises(InvalidJID, JID, local='', domain='test.com')
+ self.assertRaises(InvalidJID, JID, '@/test.com')
+
+ def testZeroLengthResource(self):
+ self.assertRaises(InvalidJID, JID, domain='test.com', resource='')
+ self.assertRaises(InvalidJID, JID, 'test.com/')
+
+ def test1023LengthDomain(self):
+ domain = ('a.' * 509) + 'a.com'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ def test1023LengthLocalPart(self):
+ local = 'a' * 1023
+ jid1 = JID(local=local, domain='test.com')
+ jid2 = JID('%s@test.com' % local)
+
+ def test1023LengthResource(self):
+ resource = 'r' * 1023
+ jid1 = JID(domain='test.com', resource=resource)
+ jid2 = JID('test.com/%s' % resource)
+
+ def test1024LengthDomain(self):
+ domain = ('a.' * 509) + 'aa.com'
+ self.assertRaises(InvalidJID, JID, domain=domain)
+ self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
+
+ def test1024LengthLocalPart(self):
+ local = 'a' * 1024
+ self.assertRaises(InvalidJID, JID, local=local, domain='test.com')
+ self.assertRaises(InvalidJID, JID, '%s@/test.com' % local)
+
+ def test1024LengthResource(self):
+ resource = 'r' * 1024
+ self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource)
+ self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource)
+
+ def testTooLongDomainLabel(self):
+ domain = ('a' * 64) + '.com'
+ self.assertRaises(InvalidJID, JID, domain=domain)
+ self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
+
+ def testDomainEmptyLabel(self):
+ domain = 'aaa..bbb.com'
+ self.assertRaises(InvalidJID, JID, domain=domain)
+ self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
+
+ def testDomainIPv4(self):
+ domain = '127.0.0.1'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ def testDomainIPv6(self):
+ domain = '[::1]'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ def testDomainInvalidIPv6NoBrackets(self):
+ domain = '::1'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ self.assertEqual(jid1.domain, '[::1]')
+ self.assertEqual(jid2.domain, '[::1]')
+
+ def testDomainInvalidIPv6MissingBracket(self):
+ domain = '[::1'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ self.assertEqual(jid1.domain, '[::1]')
+ self.assertEqual(jid2.domain, '[::1]')
+
+ def testDomainWithPort(self):
+ domain = 'example.com:5555'
+ self.assertRaises(InvalidJID, JID, domain=domain)
+ self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
+
+ def testDomainWithTrailingDot(self):
+ domain = 'example.com.'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ self.assertEqual(jid1.domain, 'example.com')
+ self.assertEqual(jid2.domain, 'example.com')
+
+ def testDomainWithDashes(self):
+ domain = 'example.com-'
+ self.assertRaises(InvalidJID, JID, domain=domain)
+ self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
+
+ domain = '-example.com'
+ self.assertRaises(InvalidJID, JID, domain=domain)
+ self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
+
+ def testACEDomain(self):
+ domain = 'xn--bcher-kva.ch'
+ jid1 = JID(domain=domain)
+ jid2 = JID('user@%s/resource' % domain)
+
+ self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
+ self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
+
+ def testJIDEscapeExistingSequences(self):
+ jid = JID(local='blah\\foo\\20bar', domain='example.com')
+ self.assertEqual(jid.local, 'blah\\foo\\5c20bar')
+
+ def testJIDEscape(self):
+ jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
+ domain='example.com')
+ self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)')
+
+ def testJIDUnescape(self):
+ jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
+ domain='example.com')
+ ujid = jid.unescape()
+ self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")')
+
+ jid = JID(local='blah\\foo\\20bar', domain='example.com')
+ ujid = jid.unescape()
+ self.assertEqual(ujid.local, 'blah\\foo\\20bar')
+
+ def testStartOrEndWithEscapedSpaces(self):
+ local = ' foo'
+ self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
+ self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
+
+ local = 'bar '
+ self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
+ self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
+
+ # Need more input for these cases. A JID starting with \20 *is* valid
+ # according to RFC 6122, but is not according to XEP-0106.
+ #self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2')
+ #self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20')
+
suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)