summaryrefslogtreecommitdiff
path: root/slixmpp/plugins/xep_0065/socks5.py
blob: 54267b32e50ce6dd100d45973263b22c084b609e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
'''Pure asyncio implementation of RFC 1928 - SOCKS Protocol Version 5.'''

import asyncio
import enum
import logging
import socket
import struct

from slixmpp.stringprep import punycode, StringprepError


# Module-level logger; the protocol below reports its handshake steps on it
# at DEBUG level.
log = logging.getLogger(__name__)


class ProtocolMismatch(Exception):
    """Raised when a handshake message does not carry SOCKS version 5.

    Only SOCKS5 is implemented; no other version or protocol is accepted.
    """


class ProtocolError(Exception):
    """Raised when a server message violates the SOCKS5 wire format.

    For example, a method-selection reply that is not exactly two bytes long.
    """


class MethodMismatch(Exception):
    """Raised when the server selects a method that we did not offer."""


class MethodUnacceptable(Exception):
    """Raised when the server accepts none of the methods we offered."""


class AddressTypeUnacceptable(Exception):
    """Raised for an address that cannot be decoded.

    Either its address type (ATYP) is not IPv4, IPv6 or domain name, or the
    packed IP address has the wrong length for its type.
    """


class ReplyError(Exception):
    '''The server answered with an error.

    :param result: The REP field of the server's reply. Codes 0 to 8 are
                   mapped to the descriptions of RFC 1928 section 6. Any other
                   value is reported as "Unknown error". The raw code is kept
                   in the ``result`` attribute.
    '''

    possible_values = (
        "succeeded",
        "general SOCKS server failure",
        "connection not allowed by ruleset",
        "Network unreachable",
        "Host unreachable",
        "Connection refused",
        "TTL expired",
        "Command not supported",
        "Address type not supported",
        "Unknown error")

    def __init__(self, result):
        # Keep the numeric reply code so callers can react to it without
        # having to parse the message text.
        self.result = result
        # Out-of-range codes fall back to the final "Unknown error" entry.
        # The lower bound matters: a negative code would otherwise index the
        # tuple from its end and pick an unrelated description.
        if 0 <= result < 9:
            message = self.possible_values[result]
        else:
            message = self.possible_values[-1]
        Exception.__init__(self, message)


class Method(enum.IntEnum):
    """Authentication method identifiers used during method negotiation.

    Values match the METHOD field of RFC 1928. ``not_yet_selected`` is a
    local sentinel marking that no method-selection reply has been handled
    yet; it is not a valid wire value.
    """
    none = 0x00             # NO AUTHENTICATION REQUIRED
    gssapi = 0x01           # GSSAPI
    password = 0x02         # USERNAME/PASSWORD
    # 0x03 to 0x7F are assigned by IANA; 0x80 to 0xFE are reserved for
    # private methods.
    unacceptable = 0xFF     # NO ACCEPTABLE METHODS
    not_yet_selected = -1


class Command(enum.IntEnum):
    """Values of the CMD field of a SOCKS5 request."""
    connect = 0x01
    bind = 0x02
    udp_associate = 0x03


class AddressType(enum.IntEnum):
    """Values of the ATYP byte that prefixes every encoded address."""
    ipv4 = 0x01
    domain = 0x03
    ipv6 = 0x04


class Socks5Protocol(asyncio.Protocol):
    '''This implements SOCKS5 as an asyncio protocol.

    This is the client side of RFC 1928. Once connected to the SOCKS5 server,
    it offers the registered methods. After the server has selected one, it
    sends a CONNECT request for the destination given to the constructor.
    Progress is reported through ``event(name, payload)``:

    - ``'socks5_connected'`` with ``(addr, port)`` from the server's reply;
    - ``'socks5_data'`` with every chunk received once ``connected`` is done;
    - ``'socks5_closed'`` with the exception (or ``None``) on disconnection.

    The outcome of the CONNECT request is also exposed through the
    ``connected`` future.
    '''

    def __init__(self, dest_addr, dest_port, event):
        # Methods offered to the server; only "no authentication" by default.
        self.methods = {Method.none}
        self.selected_method = Method.not_yet_selected
        self.transport = None
        self.dest = (dest_addr, dest_port)
        # Resolved with (addr, port) from a successful reply, or with a
        # ReplyError when the server refuses the request.
        self.connected = asyncio.Future()
        self.event = event
        # Already-done future while writing is allowed. It is replaced by a
        # pending one between pause_writing() and resume_writing().
        self.paused = asyncio.Future()
        self.paused.set_result(None)

    def register_method(self, method):
        '''Register a SOCKS5 method.

        Methods are offered to the server from connection_made(), so they
        should be registered before the connection is established.
        '''
        self.methods.add(method)

    def unregister_method(self, method):
        '''Unregister a SOCKS5 method.

        :raises KeyError: if ``method`` was not registered.
        '''
        self.methods.remove(method)

    def connection_made(self, transport):
        '''Called when the connection to the SOCKS5 server is established.'''

        log.debug('SOCKS5 connection established.')

        self.transport = transport
        self._send_methods()

    def data_received(self, data):
        '''Called when we received some data from the SOCKS5 server.

        During the handshake, the data is dispatched to the method-selection
        or connect-reply handler. Afterwards it is forwarded unchanged as a
        ``socks5_data`` event.

        :raises ProtocolMismatch: if a handshake message is not SOCKS5.
        '''

        log.debug('SOCKS5 message received.')

        # If we are already connected, this is a data packet.
        if self.connected.done():
            return self.event('socks5_data', data)

        # NOTE(review): each handshake reply is assumed to arrive whole in a
        # single call. TCP does not guarantee this, so buffering would be
        # more robust.

        # Every SOCKS5 message starts with the protocol version.
        if data[0] != 5:
            raise ProtocolMismatch()

        # Then select the correct handler for the data we just received.
        if self.selected_method == Method.not_yet_selected:
            self._handle_method(data)
        else:
            self._handle_connect(data)

    def connection_lost(self, exc):
        '''Called when the connection is closed; forwards ``exc`` as an event.'''
        log.debug('SOCKS5 connection closed.')
        self.event('socks5_closed', exc)

    def pause_writing(self):
        '''Flow control: make write() wait until resume_writing() is called.'''
        self.paused = asyncio.Future()

    def resume_writing(self):
        '''Flow control: release writers waiting in write().'''
        self.paused.set_result(None)

    def write(self, data):
        '''Write ``data`` to the transport once writing is not paused.

        This is a generator-based coroutine. It must be driven, e.g. with
        ``yield from protocol.write(data)``; merely calling it writes nothing.
        '''
        yield from self.paused
        self.transport.write(data)

    def _send_methods(self):
        '''Send the methods request, first thing a client should do.

        Layout: VER (5), NMETHODS, then one byte per offered method.
        '''
        request = bytearray([5, len(self.methods)])
        request.extend(self.methods)
        self.transport.write(request)

    def _send_request(self, command):
        '''Send a request, should be done after having negotiated a method.

        Layout: VER (5), CMD, RSV (0), then DST.ADDR (prefixed with its ATYP
        byte), then DST.PORT as a big-endian 16-bit integer.
        '''
        address, port = self.dest
        request = bytearray([5, command, 0])
        request += self._encode_addr(address)
        request += struct.pack('>H', port)

        log.debug('SOCKS5 message sent.')
        self.transport.write(request)

    def _handle_method(self, data):
        '''Handle a method reply from the server.

        :raises ProtocolError: if the reply is not exactly two bytes long.
        :raises MethodUnacceptable: if the server rejected all our methods.
        :raises MethodMismatch: if the server picked a method we didn't offer.
        '''

        if len(data) != 2:
            raise ProtocolError()
        selected_method = data[1]
        # Test for "no acceptable methods" (0xFF) first. That value is
        # normally not among the offered methods, so the mismatch check below
        # would otherwise hide it behind MethodMismatch.
        if selected_method == Method.unacceptable:
            raise MethodUnacceptable()
        if selected_method not in self.methods:
            raise MethodMismatch()
        self.selected_method = selected_method
        self._send_request(Command.connect)

    def _handle_connect(self, data):
        '''Handle a connect reply from the server.

        On success, resolve ``connected`` with ``(addr, port)`` and emit
        ``socks5_connected``. On a failure reply, store the ReplyError in
        ``connected`` and close the transport.
        '''

        try:
            result = self._parse_result(data)
        except ReplyError as exception:
            # There is no address to report, so stop here. RFC 1928 requires
            # the server to close the connection after a failure reply. We
            # close our side too, so nothing it sends later is forwarded as
            # socks5_data.
            self.connected.set_exception(exception)
            self.transport.close()
            return
        self.connected.set_result(result)
        self.event('socks5_connected', result)

    def _parse_result(self, data):
        '''Parse a reply from the server.

        Layout: VER, REP, RSV, then BND.ADDR (prefixed with its ATYP byte),
        then BND.PORT as a big-endian 16-bit integer.

        :returns: ``(addr, port)``.
        :raises ReplyError: if REP is not 0 (succeeded).
        '''

        result = data[1]
        if result != 0:
            raise ReplyError(result)
        addr = self._parse_addr(data[3:-2])
        port = struct.unpack('>H', data[-2:])[0]
        return (addr, port)

    @staticmethod
    def _parse_addr(addr):
        '''Parse an address (IP or domain) from a bytestream.

        ``addr`` starts with the ATYP byte, followed by the packed IP address
        or, for a domain name, a length byte and then the name itself.

        :raises AddressTypeUnacceptable: for an unknown ATYP, or a packed IP
                                         address of the wrong length.
        :raises ProtocolError: if a domain name's length byte is wrong.
        '''

        addr_type = addr[0]
        if addr_type == AddressType.ipv6:
            try:
                return socket.inet_ntop(socket.AF_INET6, addr[1:])
            except ValueError as e:
                raise AddressTypeUnacceptable(e) from e
        if addr_type == AddressType.ipv4:
            try:
                return socket.inet_ntop(socket.AF_INET, addr[1:])
            except ValueError as e:
                raise AddressTypeUnacceptable(e) from e
        if addr_type == AddressType.domain:
            length = addr[1]
            address = addr[2:]
            if length != len(address):
                raise ProtocolError('Size mismatch')
            return address.decode()
        raise AddressTypeUnacceptable(addr_type)

    @staticmethod
    def _encode_addr(addr):
        '''Encode an address (IP or domain) into a bytestream.

        Returns the ATYP byte followed by 16 bytes for IPv6, 4 bytes for
        IPv4, or a length byte and the punycoded name for a domain name.

        :raises ValueError: if ``addr`` cannot be encoded, or if the encoded
                            domain name does not fit in 255 bytes.
        '''

        try:
            return b'\x04' + socket.inet_pton(socket.AF_INET6, addr)
        except OSError:
            pass
        # inet_pton() only accepts a full dotted quad. inet_aton() would also
        # accept shorthands such as "127.1" or "1234", silently turning a
        # numeric-looking host name into an IPv4 address.
        try:
            return b'\x01' + socket.inet_pton(socket.AF_INET, addr)
        except OSError:
            pass
        try:
            domain = punycode(addr)
        except StringprepError as e:
            raise ValueError(
                'Cannot encode address %r for SOCKS5' % (addr,)) from e
        # The domain name length is carried in a single byte.
        if len(domain) > 255:
            raise ValueError(
                'Domain name too long for SOCKS5: %r' % (addr,))
        return b'\x03' + bytes([len(domain)]) + domain