summaryrefslogtreecommitdiff
path: root/sleekxmpp/util/stringprep_profiles.py
blob: a75bb9dd8f77e0266a66e94cac2fc620a8cd11ce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import unicode_literals

import sys
import stringprep
import unicodedata


class StringPrepError(UnicodeError):
    """Raised when text fails one of the stringprep processing steps."""


def to_unicode(data):
    """
    Convert data to the interpreter's text type.

    Uses unicode() on Python 2 and str() on Python 3.
    """
    # The conditional expression only evaluates the selected branch, so
    # the name `unicode` is never looked up on Python 3.
    text_type = str if sys.version_info >= (3, 0) else unicode
    return text_type(data)


def b1_mapping(char):
    """
    Map characters from stringprep table B.1 to nothing.

    Returns the empty string when char is in table B.1 ("commonly
    mapped to nothing"), and None when this mapping does not apply.
    """
    # This previously tested table C.1.2 (non-ASCII spaces), which both
    # deleted spaces and left the real B.1 characters untouched.
    return '' if stringprep.in_table_b1(char) else None


def c12_mapping(char):
    """
    Map characters from stringprep table C.1.2 to an ASCII space.

    Returns ' ' when char is in table C.1.2, and None when this
    mapping does not apply.
    """
    if stringprep.in_table_c12(char):
        return ' '
    return None


def map_input(data, tables=None):
    """
    Each character in the input stream MUST be checked against
    a mapping table.

    Arguments:
        data   -- The text to map, one character at a time.
        tables -- An iterable of mapping functions. Each one takes a
                  single character and returns its replacement string,
                  or None if it does not apply. The first function to
                  return a non-None value wins. None (the default)
                  means no mapping tables, so data passes through
                  unchanged.

    Returns the mapped text as a single string.
    """
    # The default of None used to be iterated directly, raising
    # TypeError for any non-empty input.
    if tables is None:
        tables = ()

    result = []
    for char in data:
        # Keep the character unless some mapping supplies a replacement.
        replacement = char
        for mapping in tables:
            mapped = mapping(char)
            if mapped is not None:
                replacement = mapped
                break
        result.append(replacement)
    return ''.join(result)


def normalize(data, nfkc=True):
    """
    Optionally apply Unicode normalization form KC to data.

    A profile either requests NFKC normalization (nfkc is true) or
    no normalization at all, in which case data is returned unchanged.
    """
    if not nfkc:
        return data
    return unicodedata.normalize('NFKC', data)


def prohibit_output(data, tables=None):
    """
    Before the text can be emitted, it MUST be checked for prohibited
    code points.

    Arguments:
        data   -- The text to check.
        tables -- An iterable of predicate functions. Each one takes a
                  single character and returns a true value if that
                  character is prohibited. None (the default) means
                  nothing is prohibited.

    Raises StringPrepError on the first prohibited character found.
    """
    # The default of None used to be iterated directly, raising
    # TypeError for any non-empty input.
    if tables is None:
        tables = ()

    for char in data:
        for check in tables:
            if check(char):
                raise StringPrepError("Prohibited code point: %s" % char)


def check_bidi(data):
    """
    Enforce the stringprep rules for bidirectional text:

    1) The characters in section 5.8 MUST be prohibited.

    2) If a string contains any RandALCat character, the string MUST NOT
       contain any LCat character.

    3) If a string contains any RandALCat character, a RandALCat
       character MUST be the first character of the string, and a
       RandALCat character MUST be the last character of the string.

    Raises StringPrepError naming the violated rule. Returns None when
    data passes, including when data is empty.
    """
    # An empty string cannot break any rule, and it has no first or
    # last character to index; indexing it used to raise IndexError.
    if not data:
        return

    has_lcat = False
    has_randal = False

    for c in data:
        if stringprep.in_table_c8(c):
            raise StringPrepError("BIDI violation: section 6 (1)")
        if stringprep.in_table_d1(c):
            has_randal = True
        elif stringprep.in_table_d2(c):
            has_lcat = True

    # Rules 2 and 3 only constrain strings containing RandALCat text.
    if not has_randal:
        return

    if has_lcat:
        raise StringPrepError("BIDI violation: section 6 (2)")

    first_randal = stringprep.in_table_d1(data[0])
    last_randal = stringprep.in_table_d1(data[-1])
    if not (first_randal and last_randal):
        raise StringPrepError("BIDI violation: section 6 (3)")


def check_unassigned(data, unassigned=None):
    """
    Check text for code points that are unassigned.

    Arguments:
        data       -- The text to check.
        unassigned -- An iterable of predicate functions. Each one
                      takes a single character and returns a true
                      value if that character is unassigned. None
                      means no check is performed.

    Raises StringPrepError on the first unassigned character found.
    """
    if unassigned is None:
        return

    for char in data:
        for check in unassigned:
            if check(char):
                raise StringPrepError("Unassigned code point: %s" % char)


def create(nfkc=True, bidi=True, mappings=None,
           prohibited=None, unassigned=None):
    """
    Build a stringprep profile function.

    Arguments:
        nfkc       -- If true, apply NFKC normalization after mapping.
        bidi       -- If true, enforce the bidirectional text rules.
        mappings   -- Iterable of mapping functions given to map_input.
                      None means no mappings.
        prohibited -- Iterable of predicates given to prohibit_output.
                      None means nothing is prohibited.
        unassigned -- Iterable of predicates marking unassigned code
                      points; only consulted when the profile is
                      called with query=True.

    Returns a function profile(data, query=False). It converts data to
    text, then runs mapping, normalization, the prohibited-output
    check, the optional bidi check and the optional unassigned check,
    in that order. It returns the prepared text or raises
    StringPrepError.
    """
    # Normalize the None defaults here so the profile works even if
    # the helpers it calls iterate their table arguments directly.
    if mappings is None:
        mappings = ()
    if prohibited is None:
        prohibited = ()

    def profile(data, query=False):
        try:
            data = to_unicode(data)
        except UnicodeError:
            raise StringPrepError

        data = map_input(data, mappings)
        data = normalize(data, nfkc)
        prohibit_output(data, prohibited)
        if bidi:
            check_bidi(data)
        if query and unassigned:
            check_unassigned(data, unassigned)
        return data
    return profile