summaryrefslogtreecommitdiff
path: root/slixmpp/plugins/xep_0065/socks5.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/plugins/xep_0065/socks5.py')
-rw-r--r--slixmpp/plugins/xep_0065/socks5.py265
1 files changed, 265 insertions, 0 deletions
diff --git a/slixmpp/plugins/xep_0065/socks5.py b/slixmpp/plugins/xep_0065/socks5.py
new file mode 100644
index 00000000..54267b32
--- /dev/null
+++ b/slixmpp/plugins/xep_0065/socks5.py
@@ -0,0 +1,265 @@
+'''Pure asyncio implementation of RFC 1928 - SOCKS Protocol Version 5.'''
+
+import asyncio
+import enum
+import logging
+import socket
+import struct
+
+from slixmpp.stringprep import punycode, StringprepError
+
+
+log = logging.getLogger(__name__)
+
+
+class ProtocolMismatch(Exception):
+ '''We only implement SOCKS5, no other version or protocol.'''
+
+
+class ProtocolError(Exception):
+ '''Some protocol error.'''
+
+
+class MethodMismatch(Exception):
+ '''The server answered with a method we didn’t ask for.'''
+
+
+class MethodUnacceptable(Exception):
+ '''None of our methods is supported by the server.'''
+
+
+class AddressTypeUnacceptable(Exception):
+ '''The address type (ATYP) field isn’t one of IPv4, IPv6 or domain name.'''
+
+
+class ReplyError(Exception):
+ '''The server answered with an error.'''
+
+ possible_values = (
+ "succeeded",
+ "general SOCKS server failure",
+ "connection not allowed by ruleset",
+ "Network unreachable",
+ "Host unreachable",
+ "Connection refused",
+ "TTL expired",
+ "Command not supported",
+ "Address type not supported",
+ "Unknown error")
+
+ def __init__(self, result):
+ if result < 9:
+ Exception.__init__(self, self.possible_values[result])
+ else:
+ Exception.__init__(self, self.possible_values[9])
+
+
+class Method(enum.IntEnum):
+ '''Known methods for a SOCKS5 session.'''
+ none = 0
+ gssapi = 1
+ password = 2
+ # Methods 3 to 127 are reserved by IANA.
+ # Methods 128 to 254 are reserved for private use.
+ unacceptable = 255
+ not_yet_selected = -1
+
+
+class Command(enum.IntEnum):
+ '''Existing commands for requests.'''
+ connect = 1
+ bind = 2
+ udp_associate = 3
+
+
+class AddressType(enum.IntEnum):
+ '''Existing address types.'''
+ ipv4 = 1
+ domain = 3
+ ipv6 = 4
+
+
+class Socks5Protocol(asyncio.Protocol):
+ '''This implements SOCKS5 as an asyncio protocol.'''
+
+ def __init__(self, dest_addr, dest_port, event):
+ self.methods = {Method.none}
+ self.selected_method = Method.not_yet_selected
+ self.transport = None
+ self.dest = (dest_addr, dest_port)
+ self.connected = asyncio.Future()
+ self.event = event
+ self.paused = asyncio.Future()
+ self.paused.set_result(None)
+
+ def register_method(self, method):
+ '''Register a SOCKS5 method.'''
+ self.methods.add(method)
+
+ def unregister_method(self, method):
+ '''Unregister a SOCKS5 method.'''
+ self.methods.remove(method)
+
+ def connection_made(self, transport):
+ '''Called when the connection to the SOCKS5 server is established.'''
+
+ log.debug('SOCKS5 connection established.')
+
+ self.transport = transport
+ self._send_methods()
+
+ def data_received(self, data):
+ '''Called when we received some data from the SOCKS5 server.'''
+
+ log.debug('SOCKS5 message received.')
+
+ # If we are already connected, this is a data packet.
+ if self.connected.done():
+ return self.event('socks5_data', data)
+
+ # Every SOCKS5 message starts with the protocol version.
+ if data[0] != 5:
+ raise ProtocolMismatch()
+
+ # Then select the correct handler for the data we just received.
+ if self.selected_method == Method.not_yet_selected:
+ self._handle_method(data)
+ else:
+ self._handle_connect(data)
+
+ def connection_lost(self, exc):
+ log.debug('SOCKS5 connection closed.')
+ self.event('socks5_closed', exc)
+
+ def pause_writing(self):
+ self.paused = asyncio.Future()
+
+ def resume_writing(self):
+ self.paused.set_result(None)
+
+ def write(self, data):
+ yield from self.paused
+ self.transport.write(data)
+
+ def _send_methods(self):
+ '''Send the methods request, first thing a client should do.'''
+
+ # Create the buffer for our request.
+ request = bytearray(len(self.methods) + 2)
+
+ # Protocol version.
+ request[0] = 5
+
+ # Number of methods to send.
+ request[1] = len(self.methods)
+
+ # List every method we support.
+ for i, method in enumerate(self.methods):
+ request[i + 2] = method
+
+ # Send the request.
+ self.transport.write(request)
+
+ def _send_request(self, command):
+ '''Send a request, should be done after having negociated a method.'''
+
+ # Encode the destination address to embed it in our request.
+ # We need to do that first because its length is variable.
+ address, port = self.dest
+ addr = self._encode_addr(address)
+
+ # Create the buffer for our request.
+ request = bytearray(5 + len(addr))
+
+ # Protocol version.
+ request[0] = 5
+
+ # Specify the command we want to use.
+ request[1] = command
+
+ # request[2] is reserved, keeping it at 0.
+
+ # Add our destination address and port.
+ request[3:3+len(addr)] = addr
+ request[-2:] = struct.pack('>H', port)
+
+ # Send the request.
+ log.debug('SOCKS5 message sent.')
+ self.transport.write(request)
+
+ def _handle_method(self, data):
+ '''Handle a method reply from the server.'''
+
+ if len(data) != 2:
+ raise ProtocolError()
+ selected_method = data[1]
+ if selected_method not in self.methods:
+ raise MethodMismatch()
+ if selected_method == Method.unacceptable:
+ raise MethodUnacceptable()
+ self.selected_method = selected_method
+ self._send_request(Command.connect)
+
+ def _handle_connect(self, data):
+ '''Handle a connect reply from the server.'''
+
+ try:
+ addr, port = self._parse_result(data)
+ except ReplyError as exception:
+ self.connected.set_exception(exception)
+ self.connected.set_result((addr, port))
+ self.event('socks5_connected', (addr, port))
+
+ def _parse_result(self, data):
+ '''Parse a reply from the server.'''
+
+ result = data[1]
+ if result != 0:
+ raise ReplyError(result)
+ addr = self._parse_addr(data[3:-2])
+ port = struct.unpack('>H', data[-2:])[0]
+ return (addr, port)
+
+ @staticmethod
+ def _parse_addr(addr):
+ '''Parse an address (IP or domain) from a bytestream.'''
+
+ addr_type = addr[0]
+ if addr_type == AddressType.ipv6:
+ try:
+ return socket.inet_ntop(socket.AF_INET6, addr[1:])
+ except ValueError as e:
+ raise AddressTypeUnacceptable(e)
+ if addr_type == AddressType.ipv4:
+ try:
+ return socket.inet_ntop(socket.AF_INET, addr[1:])
+ except ValueError as e:
+ raise AddressTypeUnacceptable(e)
+ if addr_type == AddressType.domain:
+ length = addr[1]
+ address = addr[2:]
+ if length != len(address):
+ raise Exception('Size mismatch')
+ return address.decode()
+ raise AddressTypeUnacceptable(addr_type)
+
+ @staticmethod
+ def _encode_addr(addr):
+ '''Encode an address (IP or domain) into a bytestream.'''
+
+ try:
+ ipv6 = socket.inet_pton(socket.AF_INET6, addr)
+ return b'\x04' + ipv6
+ except OSError:
+ pass
+ try:
+ ipv4 = socket.inet_aton(addr)
+ return b'\x01' + ipv4
+ except OSError:
+ pass
+ try:
+ domain = punycode(addr)
+ return b'\x03' + bytes([len(domain)]) + domain
+ except StringprepError:
+ pass
+ raise Exception('Err…')