From c2f6f077762282d311a6f876f94cc1a4eb9e805f Mon Sep 17 00:00:00 2001 From: Florent Le Coz Date: Sun, 20 Jul 2014 20:46:03 +0200 Subject: Make xmlstream use an asyncio loop Scheduled events, connection, TLS handshake (with STARTTLS), read and write on the socket are all done using only asyncio. A lot of threads, and thread-related (and thus useless) things still remain. This is only a first step. --- slixmpp/basexmpp.py | 27 +- slixmpp/clientxmpp.py | 9 +- slixmpp/features/feature_bind/bind.py | 7 +- slixmpp/features/feature_mechanisms/mechanisms.py | 4 +- slixmpp/features/feature_session/session.py | 3 +- slixmpp/features/feature_starttls/starttls.py | 3 +- slixmpp/xmlstream/__init__.py | 3 +- slixmpp/xmlstream/scheduler.py | 250 ------ slixmpp/xmlstream/xmlstream.py | 922 +++++----------------- 9 files changed, 221 insertions(+), 1007 deletions(-) delete mode 100644 slixmpp/xmlstream/scheduler.py diff --git a/slixmpp/basexmpp.py b/slixmpp/basexmpp.py index 03557218..2d83385f 100644 --- a/slixmpp/basexmpp.py +++ b/slixmpp/basexmpp.py @@ -213,37 +213,12 @@ class BaseXMPP(XMLStream): log.warning('Legacy XMPP 0.9 protocol detected.') self.event('legacy_protocol') - def process(self, *args, **kwargs): - """Initialize plugins and begin processing the XML stream. - - The number of threads used for processing stream events is determined - by :data:`HANDLER_THREADS`. - - :param bool block: If ``False``, then event dispatcher will run - in a separate thread, allowing for the stream to be - used in the background for another application. - Otherwise, ``process(block=True)`` blocks the current - thread. Defaults to ``False``. - :param bool threaded: **DEPRECATED** - If ``True``, then event dispatcher will run - in a separate thread, allowing for the stream to be - used in the background for another application. - Defaults to ``True``. This does **not** mean that no - threads are used at all if ``threaded=False``. - - Regardless of these threading options, these threads will - always exist: - - - The event queue processor - - The send queue processor - - The scheduler - """ + def init_plugins(self, *args, **kwargs): for name in self.plugin: if not hasattr(self.plugin[name], 'post_inited'): if hasattr(self.plugin[name], 'post_init'): self.plugin[name].post_init() self.plugin[name].post_inited = True - return XMLStream.process(self, *args, **kwargs) def register_plugin(self, plugin, pconfig={}, module=None): """Register and configure a plugin for use in this stream. diff --git a/slixmpp/clientxmpp.py b/slixmpp/clientxmpp.py index 55c82f82..ae9010d4 100644 --- a/slixmpp/clientxmpp.py +++ b/slixmpp/clientxmpp.py @@ -128,8 +128,8 @@ class ClientXMPP(BaseXMPP): def password(self, value): self.credentials['password'] = value - def connect(self, address=tuple(), reattempt=True, - use_tls=True, use_ssl=False): + def connect(self, address=tuple(), use_ssl=False, + force_starttls=True, disable_starttls=False): """Connect to the XMPP server. When no address is given, a SRV lookup for the server will @@ -155,9 +155,8 @@ class ClientXMPP(BaseXMPP): address = (self.boundjid.host, 5222) self.dns_service = 'xmpp-client' - return XMLStream.connect(self, address[0], address[1], - use_tls=use_tls, use_ssl=use_ssl, - reattempt=reattempt) + return XMLStream.connect(self, address[0], address[1], use_ssl=use_ssl, + force_starttls=force_starttls, disable_starttls=disable_starttls) def register_feature(self, name, handler, restart=False, order=5000): """Register a stream feature handler. diff --git a/slixmpp/features/feature_bind/bind.py b/slixmpp/features/feature_bind/bind.py index ac69ee77..f636abf9 100644 --- a/slixmpp/features/feature_bind/bind.py +++ b/slixmpp/features/feature_bind/bind.py @@ -42,13 +42,16 @@ class FeatureBind(BasePlugin): features -- The stream features stanza. """ log.debug("Requesting resource: %s", self.xmpp.requested_jid.resource) + self.features = features iq = self.xmpp.Iq() iq['type'] = 'set' iq.enable('bind') if self.xmpp.requested_jid.resource: iq['bind']['resource'] = self.xmpp.requested_jid.resource - response = iq.send(now=True) + iq.send(block=False, callback=self._on_bind_response) + + def _on_bind_response(self, response): self.xmpp.boundjid = JID(response['bind']['jid'], cache_lock=True) self.xmpp.bound = True self.xmpp.event('session_bind', self.xmpp.boundjid, direct=True) @@ -58,7 +61,7 @@ class FeatureBind(BasePlugin): log.info("JID set to: %s", self.xmpp.boundjid.full) - if 'session' not in features['features']: + if 'session' not in self.features['features']: log.debug("Established Session") self.xmpp.sessionstarted = True self.xmpp.session_started_event.set() diff --git a/slixmpp/features/feature_mechanisms/mechanisms.py b/slixmpp/features/feature_mechanisms/mechanisms.py index 663bfe57..3cbb83f2 100644 --- a/slixmpp/features/feature_mechanisms/mechanisms.py +++ b/slixmpp/features/feature_mechanisms/mechanisms.py @@ -233,7 +233,9 @@ class FeatureMechanisms(BasePlugin): self.xmpp.authenticated = True self.xmpp.features.add('mechanisms') self.xmpp.event('auth_success', stanza, direct=True) - raise RestartStream() + # Restart the stream + self.xmpp.init_parser() + self.xmpp.send_raw(self.xmpp.stream_header) def _handle_fail(self, stanza): """SASL authentication failed. Disconnect and shutdown.""" diff --git a/slixmpp/features/feature_session/session.py b/slixmpp/features/feature_session/session.py index c2694a9f..08f7480f 100644 --- a/slixmpp/features/feature_session/session.py +++ b/slixmpp/features/feature_session/session.py @@ -44,8 +44,9 @@ class FeatureSession(BasePlugin): iq = self.xmpp.Iq() iq['type'] = 'set' iq.enable('session') - iq.send(now=True) + iq.send(block=False, callback=self._on_start_session_response) + def _on_start_session_response(self, response): self.xmpp.features.add('session') log.debug("Established Session") diff --git a/slixmpp/features/feature_starttls/starttls.py b/slixmpp/features/feature_starttls/starttls.py index 4b9dd60b..a05f755b 100644 --- a/slixmpp/features/feature_starttls/starttls.py +++ b/slixmpp/features/feature_starttls/starttls.py @@ -52,7 +52,7 @@ class FeatureSTARTTLS(BasePlugin): # We have already negotiated TLS, but the server is # offering it again, against spec. return False - elif not self.xmpp.use_tls: + elif self.xmpp.disable_starttls: return False else: self.xmpp.send(features['starttls'], now=True) @@ -63,4 +63,3 @@ class FeatureSTARTTLS(BasePlugin): log.debug("Starting TLS") if self.xmpp.start_tls(): self.xmpp.features.add('starttls') - raise RestartStream() diff --git a/slixmpp/xmlstream/__init__.py b/slixmpp/xmlstream/__init__.py index 6b04d35c..fa192265 100644 --- a/slixmpp/xmlstream/__init__.py +++ b/slixmpp/xmlstream/__init__.py @@ -7,13 +7,12 @@ """ from slixmpp.jid import JID -from slixmpp.xmlstream.scheduler import Scheduler from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase, ET from slixmpp.xmlstream.stanzabase import register_stanza_plugin from slixmpp.xmlstream.tostring import tostring from slixmpp.xmlstream.xmlstream import XMLStream, RESPONSE_TIMEOUT from slixmpp.xmlstream.xmlstream import RestartStream -__all__ = ['JID', 'Scheduler', 'StanzaBase', 'ElementBase', +__all__ = ['JID', 'StanzaBase', 'ElementBase', 'ET', 'StateMachine', 'tostring', 'XMLStream', 'RESPONSE_TIMEOUT', 'RestartStream'] diff --git a/slixmpp/xmlstream/scheduler.py b/slixmpp/xmlstream/scheduler.py deleted file mode 100644 index 137230b6..00000000 --- a/slixmpp/xmlstream/scheduler.py +++ /dev/null @@ -1,250 +0,0 @@ -# -*- coding: utf-8 -*- -""" - slixmpp.xmlstream.scheduler - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - This module provides a task scheduler that works better - with Slixmpp's threading usage than the stock version. - - Part of Slixmpp: The Slick XMPP Library - - :copyright: (c) 2011 Nathanael C. Fritz - :license: MIT, see LICENSE for more details -""" - -import time -import threading -import logging -import itertools - -from slixmpp.util import Queue, QueueEmpty - - -#: The time in seconds to wait for events from the event queue, and also the -#: time between checks for the process stop signal. -WAIT_TIMEOUT = 1.0 - - -log = logging.getLogger(__name__) - - -class Task(object): - - """ - A scheduled task that will be executed by the scheduler - after a given time interval has passed. - - :param string name: The name of the task. - :param int seconds: The number of seconds to wait before executing. - :param callback: The function to execute. - :param tuple args: The arguments to pass to the callback. - :param dict kwargs: The keyword arguments to pass to the callback. - :param bool repeat: Indicates if the task should repeat. - Defaults to ``False``. - :param pointer: A pointer to an event queue for queuing callback - execution instead of executing immediately. - """ - - def __init__(self, name, seconds, callback, args=None, - kwargs=None, repeat=False, qpointer=None): - #: The name of the task. - self.name = name - - #: The number of seconds to wait before executing. - self.seconds = seconds - - #: The function to execute once enough time has passed. - self.callback = callback - - #: The arguments to pass to :attr:`callback`. - self.args = args or tuple() - - #: The keyword arguments to pass to :attr:`callback`. - self.kwargs = kwargs or {} - - #: Indicates if the task should repeat after executing, - #: using the same :attr:`seconds` delay. - self.repeat = repeat - - #: The time when the task should execute next. - self.next = time.time() + self.seconds - - #: The main event queue, which allows for callbacks to - #: be queued for execution instead of executing immediately. - self.qpointer = qpointer - - def run(self): - """Execute the task's callback. - - If an event queue was supplied, place the callback in the queue; - otherwise, execute the callback immediately. - """ - if self.qpointer is not None: - self.qpointer.put(('schedule', self.callback, - self.args, self.kwargs, self.name)) - else: - self.callback(*self.args, **self.kwargs) - self.reset() - return self.repeat - - def reset(self): - """Reset the task's timer so that it will repeat.""" - self.next = time.time() + self.seconds - - -class Scheduler(object): - - """ - A threaded scheduler that allows for updates mid-execution unlike the - scheduler in the standard library. - - Based on: http://docs.python.org/library/sched.html#module-sched - - :param parentstop: An :class:`~threading.Event` to signal stopping - the scheduler. - """ - - def __init__(self, parentstop=None): - #: A queue for storing tasks - self.addq = Queue() - - #: A list of tasks in order of execution time. - self.schedule = [] - - #: If running in threaded mode, this will be the thread processing - #: the schedule. - self.thread = None - - #: A flag indicating that the scheduler is running. - self.run = False - - #: An :class:`~threading.Event` instance for signalling to stop - #: the scheduler. - self.stop = parentstop - - #: Lock for accessing the task queue. - self.schedule_lock = threading.RLock() - - #: The time in seconds to wait for events from the event queue, - #: and also the time between checks for the process stop signal. - self.wait_timeout = WAIT_TIMEOUT - - def process(self, threaded=True, daemon=False): - """Begin accepting and processing scheduled tasks. - - :param bool threaded: Indicates if the scheduler should execute - in its own thread. Defaults to ``True``. - """ - if threaded: - self.thread = threading.Thread(name='scheduler_process', - target=self._process) - self.thread.daemon = daemon - self.thread.start() - else: - self._process() - - def _process(self): - """Process scheduled tasks.""" - self.run = True - try: - while self.run and not self.stop.is_set(): - updated = False - if self.schedule: - wait = self.schedule[0].next - time.time() - else: - wait = self.wait_timeout - try: - if wait <= 0.0: - newtask = self.addq.get(False) - else: - newtask = None - while self.run and \ - not self.stop.is_set() and \ - newtask is None and \ - wait > 0: - try: - newtask = self.addq.get(True, min(wait, self.wait_timeout)) - except QueueEmpty: # Nothing to add, nothing to do. Check run flags and continue waiting. - wait -= self.wait_timeout - except QueueEmpty: # Time to run some tasks, and no new tasks to add. - self.schedule_lock.acquire() - # select only those tasks which are to be executed now - relevant = itertools.takewhile( - lambda task: time.time() >= task.next, self.schedule) - # run the tasks and keep the return value in a tuple - status = map(lambda task: (task, task.run()), relevant) - # remove non-repeating tasks - for task, doRepeat in status: - if not doRepeat: - try: - self.schedule.remove(task) - except ValueError: - pass - else: - # only need to resort tasks if a repeated task has - # been kept in the list. - updated = True - else: # Add new task - self.schedule_lock.acquire() - if newtask is not None: - self.schedule.append(newtask) - updated = True - finally: - if updated: - self.schedule.sort(key=lambda task: task.next) - self.schedule_lock.release() - except KeyboardInterrupt: - self.run = False - except SystemExit: - self.run = False - log.debug("Quitting Scheduler thread") - - def add(self, name, seconds, callback, args=None, - kwargs=None, repeat=False, qpointer=None): - """Schedule a new task. - - :param string name: The name of the task. - :param int seconds: The number of seconds to wait before executing. - :param callback: The function to execute. - :param tuple args: The arguments to pass to the callback. - :param dict kwargs: The keyword arguments to pass to the callback. - :param bool repeat: Indicates if the task should repeat. - Defaults to ``False``. - :param pointer: A pointer to an event queue for queuing callback - execution instead of executing immediately. - """ - try: - self.schedule_lock.acquire() - for task in self.schedule: - if task.name == name: - raise ValueError("Key %s already exists" % name) - - self.addq.put(Task(name, seconds, callback, args, - kwargs, repeat, qpointer)) - except: - raise - finally: - self.schedule_lock.release() - - def remove(self, name): - """Remove a scheduled task ahead of schedule, and without - executing it. - - :param string name: The name of the task to remove. - """ - try: - self.schedule_lock.acquire() - the_task = None - for task in self.schedule: - if task.name == name: - the_task = task - if the_task is not None: - self.schedule.remove(the_task) - except: - raise - finally: - self.schedule_lock.release() - - def quit(self): - """Shutdown the scheduler.""" - self.run = False diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 040f5096..5df656ad 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -14,6 +14,8 @@ from __future__ import with_statement, unicode_literals +import asyncio +import functools import base64 import copy import logging @@ -29,22 +31,17 @@ import uuid import errno from xml.parsers.expat import ExpatError +import xml.etree.ElementTree import slixmpp from slixmpp.util import Queue, QueueEmpty, safedict from slixmpp.thirdparty.statemachine import StateMachine -from slixmpp.xmlstream import Scheduler, tostring, cert +from slixmpp.xmlstream import tostring, cert from slixmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from slixmpp.xmlstream.handler import Waiter, XMLCallback from slixmpp.xmlstream.matcher import MatchXMLMask from slixmpp.xmlstream.resolver import resolve, default_resolver -# In Python 2.x, file socket objects are broken. A patched socket -# wrapper is provided for this case in filesocket.py. -if sys.version_info < (3, 0): - from slixmpp.xmlstream.filesocket import FileSocket, Socket26 - - #: The time in seconds to wait before timing out waiting for response stanzas. RESPONSE_TIMEOUT = 30 @@ -115,6 +112,26 @@ class XMLStream(object): """ def __init__(self, socket=None, host='', port=0): + # The asyncio.Transport object provided by the connection_made() + # callback when we are connected + self.transport = None + + # The socket the is used internally by the transport object + self.socket = None + + self.parser = None + self.xml_depth = 0 + self.xml_root = None + + self.force_starttls = None + self.disable_starttls = None + + # A dict of {name: handle} + self.scheduled_events = {} + + self.ssl_context = ssl.create_default_context() + self.ssl_context.check_hostname = False + self.ssl_context.verify_mode = ssl.CERT_NONE #: Most XMPP servers support TLSv1, but OpenFire in particular #: does not work well with it. For OpenFire, set #: :attr:`ssl_version` to use ``SSLv23``:: @@ -197,26 +214,11 @@ class XMLStream(object): #: The desired, or actual, address of the connected server. self.address = (host, int(port)) - #: A file-like wrapper for the socket for use with the - #: :mod:`~xml.etree.ElementTree` module. - self.filesocket = None - self.set_socket(socket) - - if sys.version_info < (3, 0): - self.socket_class = Socket26 - else: - self.socket_class = Socket.socket - #: Enable connecting to the server directly over SSL, in #: particular when the service provides two ports: one for #: non-SSL traffic and another for SSL traffic. self.use_ssl = False - #: Enable connecting to the service without using SSL - #: immediately, but allow upgrading the connection later - #: to use SSL. - self.use_tls = False - #: If set to ``True``, attempt to connect through an HTTP #: proxy based on the settings in :attr:`proxy_config`. self.use_proxy = False @@ -287,17 +289,11 @@ class XMLStream(object): #: if the connection is terminated. self.end_session_on_disconnect = True - #: A queue of stream, custom, and scheduled events to be processed. - self.event_queue = Queue() - #: A queue of string data to be sent over the stream. self.send_queue = Queue() self.send_queue_lock = threading.Lock() self.send_lock = threading.RLock() - #: A :class:`~slixmpp.xmlstream.scheduler.Scheduler` instance for - #: executing callbacks in the future based on time delays. - self.scheduler = Scheduler(self.stop) self.__failed_send_stanza = None #: A mapping of XML namespaces to well-known prefixes. @@ -307,7 +303,6 @@ class XMLStream(object): self.__root_stanza = [] self.__handlers = [] self.__event_handlers = {} - self.__event_handlers_lock = threading.Lock() self.__filters = {'in': [], 'out': [], 'out_sync': []} self.__thread_count = 0 self.__thread_cond = threading.Condition() @@ -407,23 +402,27 @@ class XMLStream(object): return "%s%X" % (self._id_prefix, self._id) def connect(self, host='', port=0, use_ssl=False, - use_tls=True, reattempt=True): + force_starttls=True, disable_starttls=False): """Create a new socket and connect to the server. - Setting ``reattempt`` to ``True`` will cause connection - attempts to be made with an exponential backoff delay (max of - :attr:`reconnect_max_delay` which defaults to 10 minute) until a - successful connection is established. - :param host: The name of the desired server for the connection. :param port: Port to connect to on the server. :param use_ssl: Flag indicating if SSL should be used by connecting - directly to a port using SSL. - :param use_tls: Flag indicating if TLS should be used, allowing for - connecting to a port without using SSL immediately and - later upgrading the connection. - :param reattempt: Flag indicating if the socket should reconnect - after disconnections. + directly to a port using SSL. If it is False, the + connection will be upgraded to SSL/TLS later, using + STARTTLS. Only use this value for old servers that + have specific port for SSL/TLS + TODO fix the comment + :param force_starttls: If True, the connection will be aborted if + the server does not initiate a STARTTLS + negociation. If None, the connection will be + upgraded to TLS only if the server initiate + the STARTTLS negociation, otherwise it will + connect in clear. If False it will never + upgrade to TLS, even if the server provides + it. Use this for example if you’re on + localhost + """ self.stop.clear() @@ -434,212 +433,76 @@ class XMLStream(object): except (Socket.error, ssl.SSLError): self.default_domain = self.address[0] - # Respect previous SSL and TLS usage directives. + # Respect previous TLS usage. if use_ssl is not None: self.use_ssl = use_ssl - if use_tls is not None: - self.use_tls = use_tls - - # Repeatedly attempt to connect until a successful connection - # is established. - attempts = self.reconnect_max_attempts - connected = self.state.transition('disconnected', 'connected', - func=self._connect, - args=(reattempt,)) - while reattempt and not connected and not self.stop.is_set(): - connected = self.state.transition('disconnected', 'connected', - func=self._connect) - if not connected: - if attempts is not None: - attempts -= 1 - if attempts <= 0: - self.event('connection_failed', direct=True) - return False - return connected - - def _connect(self, reattempt=True): - self.scheduler.remove('Session timeout check') - - if self.reconnect_delay is None or not reattempt: - delay = 1.0 - else: - delay = min(self.reconnect_delay * 2, self.reconnect_max_delay) - delay = random.normalvariate(delay, delay * 0.1) - log.debug('Waiting %s seconds before connecting.', delay) - elapsed = 0 - try: - while elapsed < delay and not self.stop.is_set(): - time.sleep(0.1) - elapsed += 0.1 - except KeyboardInterrupt: - self.set_stop() - return False - except SystemExit: - self.set_stop() - return False - - if self.default_domain: - try: - host, address, port = self.pick_dns_answer(self.default_domain, - self.address[1]) - self.address = (address, port) - self._service_name = host - except StopIteration: - log.debug("No remaining DNS records to try.") - self.dns_answers = None - if reattempt: - self.reconnect_delay = delay - return False - - af = Socket.AF_INET - proto = 'IPv4' - if ':' in self.address[0]: - af = Socket.AF_INET6 - proto = 'IPv6' - try: - self.socket = self.socket_class(af, Socket.SOCK_STREAM) - except Socket.error: - log.debug("Could not connect using %s", proto) - return False - - self.configure_socket() - - if self.use_proxy: - connected = self._connect_proxy() - if not connected: - if reattempt: - self.reconnect_delay = delay - return False - - if self.use_ssl: - log.debug("Socket Wrapped for SSL") - if self.ca_certs is None: - cert_policy = ssl.CERT_NONE - else: - cert_policy = ssl.CERT_REQUIRED - - ssl_args = safedict({ - 'certfile': self.certfile, - 'keyfile': self.keyfile, - 'ca_certs': self.ca_certs, - 'cert_reqs': cert_policy, - 'do_handshake_on_connect': False - }) - - if sys.version_info >= (2, 7): - ssl_args['ciphers'] = self.ciphers - - ssl_socket = ssl.wrap_socket(self.socket, **ssl_args) - - if hasattr(self.socket, 'socket'): - # We are using a testing socket, so preserve the top - # layer of wrapping. - self.socket.socket = ssl_socket - else: - self.socket = ssl_socket + if force_starttls is not None: + self.force_starttls = force_starttls + if disable_starttls is not None: + self.disable_starttls = disable_starttls + + loop = asyncio.get_event_loop() + connect_routine = loop.create_connection(lambda: self, + self.address[0], + self.address[1], + ssl=self.use_ssl) + asyncio.async(connect_routine) + + def init_parser(self): + self.xml_depth = 0 + self.xml_root = None + self.parser = xml.etree.ElementTree.XMLPullParser(("start", "end")) + + def connection_made(self, transport): + self.transport = transport + self.socket = self.transport.get_extra_info("socket") + self.init_parser() + self.send_raw(self.stream_header) + + def data_received(self, data): + self.parser.feed(data) + for event, xml in self.parser.read_events(): + if event == 'start': + if self.xml_depth == 0: + # We have received the start of the root element. + self.xml_root = xml + log.debug('RECV: %s', tostring(self.xml_root, xmlns=self.default_ns, + stream=self, + top_level=True, + open_only=True)) + # Perform any stream initialization actions, such + # as handshakes. + self.stream_end_event.clear() + self.start_stream_handler(self.xml_root) - try: - if not self.use_proxy: - domain = self.address[0] - if ':' in domain: - domain = '[%s]' % domain - log.debug("Connecting to %s:%s", domain, self.address[1]) - self.socket.connect(self.address) - - if self.use_ssl: - try: - self.socket.do_handshake() - except (Socket.error, ssl.SSLError): - log.error('CERT: Invalid certificate trust chain.') - if not self.event_handled('ssl_invalid_chain'): - self.disconnect(self.auto_reconnect, - send_close=False) - else: - self.event('ssl_invalid_chain', direct=True) - return False - - self._der_cert = self.socket.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) - log.debug('CERT: %s', pem_cert) - - self.event('ssl_cert', pem_cert, direct=True) + # We have a successful stream connection, so reset + # exponential backoff for new reconnect attempts. + self.reconnect_delay = 1.0 + self.xml_depth += 1 + if event == 'end': + self.xml_depth -= 1 + if self.xml_depth == 0: + # The stream's root element has closed, + # terminating the stream. + log.debug("End of stream recieved") + self.stream_end_event.set() + return False + elif self.xml_depth == 1: + # We only raise events for stanzas that are direct + # children of the root element. try: - cert.verify(self._expected_server_name, self._der_cert) - except cert.CertificateError as err: - if not self.event_handled('ssl_invalid_cert'): - log.error(err) - self.disconnect(send_close=False) - else: - self.event('ssl_invalid_cert', - pem_cert, - direct=True) - - self.set_socket(self.socket, ignore=True) - #this event is where you should set your application state - self.event('connected', direct=True) - return True - except (Socket.error, ssl.SSLError) as serr: - error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" - self.event('socket_error', serr, direct=True) - domain = self.address[0] - if ':' in domain: - domain = '[%s]' % domain - log.error(error_msg, domain, self.address[1], - serr.errno, serr.strerror) - return False - - def _connect_proxy(self): - """Attempt to connect using an HTTP Proxy.""" - - # Extract the proxy address, and optional credentials - address = (self.proxy_config['host'], int(self.proxy_config['port'])) - cred = None - if self.proxy_config['username']: - username = self.proxy_config['username'] - password = self.proxy_config['password'] - - cred = '%s:%s' % (username, password) - if sys.version_info < (3, 0): - cred = bytes(cred) - else: - cred = bytes(cred, 'utf-8') - cred = base64.b64encode(cred).decode('utf-8') - - # Build the HTTP headers for connecting to the XMPP server - headers = ['CONNECT %s:%s HTTP/1.0' % self.address, - 'Host: %s:%s' % self.address, - 'Proxy-Connection: Keep-Alive', - 'Pragma: no-cache', - 'User-Agent: Slixmpp/%s' % slixmpp.__version__] - if cred: - headers.append('Proxy-Authorization: Basic %s' % cred) - headers = '\r\n'.join(headers) + '\r\n\r\n' + self.__spawn_event(xml) + except RestartStream: + return True + if self.xml_root is not None: + # Keep the root element empty of children to + # save on memory use. + self.xml_root.clear() - try: - log.debug("Connecting to proxy: %s:%s", *address) - self.socket.connect(address) - self.send_raw(headers, now=True) - resp = '' - while '\r\n\r\n' not in resp and not self.stop.is_set(): - resp += self.socket.recv(1024).decode('utf-8') - log.debug('RECV: %s', resp) - - lines = resp.split('\r\n') - if '200' not in lines[0]: - self.event('proxy_error', resp) - self.event('connection_failed', direct=True) - log.error('Proxy Error: %s', lines[0]) - return False - - # Proxy connection established, continue connecting - # with the XMPP server. - return True - except (Socket.error, ssl.SSLError) as serr: - error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" - self.event('socket_error', serr, direct=True) - log.error(error_msg, self.address[0], self.address[1], - serr.errno, serr.strerror) - return False + def connection_lost(self): + self.parser = None + self.transport = None + self.socket = None def _session_timeout_check(self, event=None): """ @@ -684,59 +547,8 @@ class XMLStream(object): prevents error loops when trying to disconnect after a socket error. """ - self.state.transition('connected', 'disconnected', - wait=2.0, - func=self._disconnect, - args=(reconnect, wait, send_close)) - - def _disconnect(self, reconnect=False, wait=None, send_close=True): - if not reconnect: - self.auto_reconnect = False - - if self.end_session_on_disconnect or send_close: - self.event('session_end', direct=True) - - # Wait for the send queue to empty. - if wait is not None: - if wait: - self.send_queue.join() - elif self.disconnect_wait: - self.send_queue.join() - - # Clearing this event will pause the send loop. - self.session_started_event.clear() - - self.__failed_send_stanza = None - - # Send the end of stream marker. - if send_close: - self.send_raw(self.stream_footer, now=True) - - # Wait for confirmation that the stream was - # closed in the other direction. If we didn't - # send a stream footer we don't need to wait - # since the server won't know to respond. - if send_close: - log.info('Waiting for %s from server', self.stream_footer) - self.stream_end_event.wait(4) - else: - self.stream_end_event.set() - - if not self.auto_reconnect: - self.set_stop() - if self._disconnect_wait_for_threads: - self._wait_for_threads() - - try: - self.socket.shutdown(Socket.SHUT_RDWR) - self.socket.close() - self.filesocket.close() - except (Socket.error, ssl.SSLError) as serr: - self.event('socket_error', serr, direct=True) - finally: - #clear your application state - self.event('disconnected', direct=True) - return True + # TODO + pass def abort(self): self.session_started_event.clear() @@ -746,7 +558,6 @@ class XMLStream(object): try: self.socket.shutdown(Socket.SHUT_RDWR) self.socket.close() - self.filesocket.close() except Socket.error: pass self.state.transition_any(['connected', 'disconnected'], 'disconnected', func=lambda: True) @@ -755,61 +566,14 @@ class XMLStream(object): def reconnect(self, reattempt=True, wait=False, send_close=True): """Reset the stream's state and reconnect to the server.""" log.debug("reconnecting...") - if self.state.ensure('connected'): - self.state.transition('connected', 'disconnected', - wait=2.0, - func=self._disconnect, - args=(True, wait, send_close)) - - attempts = self.reconnect_max_attempts - - log.debug("connecting...") - connected = self.state.transition('disconnected', 'connected', - wait=2.0, - func=self._connect, - args=(reattempt,)) - while reattempt and not connected and not self.stop.is_set(): - connected = self.state.transition('disconnected', 'connected', - wait=2.0, func=self._connect) - connected = connected or self.state.ensure('connected') - if not connected: - if attempts is not None: - attempts -= 1 - if attempts <= 0: - self.event('connection_failed', direct=True) - return False - return connected - - def set_socket(self, socket, ignore=False): - """Set the socket to use for the stream. - - The filesocket will be recreated as well. - - :param socket: The new socket object to use. - :param bool ignore: If ``True``, don't set the connection - state to ``'connected'``. - """ - self.socket = socket - if socket is not None: - # ElementTree.iterparse requires a file. - # 0 buffer files have to be binary. - - # Use the correct fileobject type based on the Python - # version to work around a broken implementation in - # Python 2.x. - if sys.version_info < (3, 0): - self.filesocket = FileSocket(self.socket) - else: - self.filesocket = self.socket.makefile('rb', 0) - if not ignore: - self.state._set_state('connected') + self.connect() def configure_socket(self): """Set timeout and other options for self.socket. Meant to be overridden. """ - self.socket.settimeout(None) + pass def configure_dns(self, resolver, domain=None, port=None): """ @@ -834,68 +598,15 @@ class XMLStream(object): If the handshake is successful, the XML stream will need to be restarted. """ - log.info("Negotiating TLS") - ssl_versions = {3: 'TLS 1.0', 1: 'SSL 3', 2: 'SSL 2/3'} - log.info("Using SSL version: %s", ssl_versions[self.ssl_version]) - if self.ca_certs is None: - cert_policy = ssl.CERT_NONE - else: - cert_policy = ssl.CERT_REQUIRED - - ssl_args = safedict({ - 'certfile': self.certfile, - 'keyfile': self.keyfile, - 'ca_certs': self.ca_certs, - 'cert_reqs': cert_policy, - 'do_handshake_on_connect': False - }) - - if sys.version_info >= (2, 7): - ssl_args['ciphers'] = self.ciphers - - ssl_socket = ssl.wrap_socket(self.socket, **ssl_args); - - if hasattr(self.socket, 'socket'): - # We are using a testing socket, so preserve the top - # layer of wrapping. - self.socket.socket = ssl_socket - else: - self.socket = ssl_socket - - try: - self.socket.do_handshake() - except (Socket.error, ssl.SSLError): - log.error('CERT: Invalid certificate trust chain.') - if not self.event_handled('ssl_invalid_chain'): - self.disconnect(self.auto_reconnect, send_close=False) - else: - self._der_cert = self.socket.getpeercert(binary_form=True) - self.event('ssl_invalid_chain', direct=True) - return False - - self._der_cert = self.socket.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) - log.debug('CERT: %s', pem_cert) - self.event('ssl_cert', pem_cert, direct=True) - - try: - cert.verify(self._expected_server_name, self._der_cert) - except cert.CertificateError as err: - if not self.event_handled('ssl_invalid_cert'): - log.error(err) - self.disconnect(self.auto_reconnect, send_close=False) - else: - self.event('ssl_invalid_cert', pem_cert, direct=True) - - self.set_socket(self.socket) - return True + loop = asyncio.get_event_loop() + ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, + sock=self.socket, + server_hostname=self.address[0]) + asyncio.async(ssl_connect_routine) def _cert_expiration(self, event): """Schedule an event for when the TLS certificate expires.""" - if not self.use_tls and not self.use_ssl: - return - if not self._der_cert: log.warn("TLS or SSL was enabled, but no certificate was found.") return @@ -942,13 +653,12 @@ class XMLStream(object): self.whitespace_keepalive_interval, self.send_raw, args=(' ',), - kwargs={'now': True}, repeat=True) def _remove_schedules(self, event): """Remove whitespace keepalive and certificate expiration schedules.""" - self.scheduler.remove('Whitespace Keepalive') - self.scheduler.remove('Certificate Expiration') + self.cancel_schedule('Whitespace Keepalive') + self.cancel_schedule('Certificate Expiration') def start_stream_handler(self, xml): """Perform any initialization actions, such as handshakes, @@ -1179,7 +889,7 @@ class XMLStream(object): else: self.exception(e) else: - self.event_queue.put(('event', handler, out_data)) + self.run_event(('event', handler, out_data)) if handler[2]: # If the handler is disposable, we will go ahead and # remove it now instead of waiting for it to be @@ -1191,12 +901,12 @@ class XMLStream(object): except: pass - def schedule(self, name, seconds, callback, args=None, - kwargs=None, repeat=False): + def schedule(self, name, seconds, callback, args=tuple(), + kwargs={}, repeat=False): """Schedule a callback function to execute after a given delay. :param name: A unique name for the scheduled callback. - :param seconds: The time in seconds to wait before executing. + :param seconds: The time in seconds to wait before executing. :param callback: A pointer to the function to execute. :param args: A tuple of arguments to pass to the function. :param kwargs: A dictionary of keyword arguments to pass to @@ -1204,8 +914,42 @@ class XMLStream(object): :param repeat: Flag indicating if the scheduled event should be reset and repeat after executing. """ - self.scheduler.add(name, seconds, callback, args, kwargs, - repeat, qpointer=self.event_queue) + loop = asyncio.get_event_loop() + cb = functools.partial(callback, *args, **kwargs) + if repeat: + handle = loop.call_later(seconds, self._execute_and_reschedule, + name, cb, seconds) + else: + handle = loop.call_later(seconds, self._execute_and_unschedule, + name, cb) + + # Save that handle, so we can just cancel this scheduled event by + # canceling scheduled_events[name] + self.scheduled_events[name] = handle + + def cancel_schedule(self, name): + try: + handle = self.scheduled_events.pop(name) + handle.cancel() + except KeyError: + log.debug("Tried to cancel unscheduled event: %s" % (name,)) + + def _execute_and_reschedule(self, name, cb, seconds): + """Simple method that calls the given callback, and then schedule itself to + be called after the given number of seconds. + """ + cb() + loop = asyncio.get_event_loop() + handle = loop.call_later(seconds, self._execute_and_reschedule, + name, cb, seconds) + self.scheduled_events[name] = handle + + def _execute_and_unschedule(self, name, cb): + """ + Execute the callback and remove the handler for it. + """ + cb() + del self.scheduled_events[name] def incoming_filter(self, xml): """Filter incoming XML objects before they are processed. @@ -1268,9 +1012,9 @@ class XMLStream(object): str_data = tostring(data.xml, xmlns=self.default_ns, stream=self, top_level=True) - self.send_raw(str_data, now) + self.send_raw(str_data) else: - self.send_raw(data, now) + self.send_raw(data) if mask is not None: return wait_for.wait(timeout) @@ -1296,57 +1040,17 @@ class XMLStream(object): timeout = self.response_timeout return self.send(tostring(data), mask, timeout, now) - def send_raw(self, data, now=False, reconnect=None): + def send_raw(self, data): """Send raw data across the stream. - :param string data: Any string value. - :param bool reconnect: Indicates if the stream should be - restarted if there is an error sending - the stanza. Used mainly for testing. - Defaults to :attr:`auto_reconnect`. + :param string data: Any bytes or utf-8 string value. """ - if now: - log.debug("SEND (IMMED): %s", data) - try: - data = data.encode('utf-8') - total = len(data) - sent = 0 - count = 0 - tries = 0 - with self.send_lock: - while sent < total and not self.stop.is_set(): - try: - sent += self.socket.send(data[sent:]) - count += 1 - except Socket.error as serr: - if serr.errno != errno.EINTR: - raise - except ssl.SSLError as serr: - if tries >= self.ssl_retry_max: - log.debug('SSL error: max retries reached') - self.exception(serr) - log.warning("Failed to send %s", data) - if reconnect is None: - reconnect = self.auto_reconnect - if not self.stop.is_set(): - self.disconnect(reconnect, - send_close=False) - log.warning('SSL write error: retrying') - if not self.stop.is_set(): - time.sleep(self.ssl_retry_delay) - tries += 1 - if count > 1: - log.debug('SENT: %d chunks', count) - except (Socket.error, ssl.SSLError) as serr: - self.event('socket_error', serr, direct=True) - log.warning("Failed to send %s", data) - if reconnect is None: - reconnect = self.auto_reconnect - if not self.stop.is_set(): - self.disconnect(reconnect, send_close=False) + if isinstance(data, str): + data = data.encode('utf-8') + if not self.transport: + logger.error("Cannot send data, we are not connected.") else: - self.send_queue.put(data) - return True + self.transport.write(data) def _start_thread(self, name, target, track=True): self.__thread[name] = threading.Thread(name=name, target=target) @@ -1386,179 +1090,6 @@ class XMLStream(object): def set_stop(self): self.stop.set() - # Unlock queues - self.event_queue.put(None) - self.send_queue.put(None) - - def _wait_for_threads(self): - with self.__thread_cond: - if self.__thread_count != 0: - log.debug("Waiting for %s threads to exit." % - self.__thread_count) - name = threading.current_thread().name - if name in self.__thread: - self._end_thread(name, early=True) - self.__thread_cond.wait(4) - if self.__thread_count != 0: - log.error("Hanged threads: %s" % threading.enumerate()) - log.error("This may be due to calling disconnect() " + \ - "from a non-threaded event handler. Be " + \ - "sure that event handlers that call " + \ - "disconnect() are registered using: " + \ - "add_event_handler(..., threaded=True)") - - def process(self, **kwargs): - """Initialize the XML streams and begin processing events. - - The number of threads used for processing stream events is determined - by :data:`HANDLER_THREADS`. - - :param bool block: If ``False``, then event dispatcher will run - in a separate thread, allowing for the stream to be - used in the background for another application. - Otherwise, ``process(block=True)`` blocks the current - thread. Defaults to ``False``. - :param bool threaded: **DEPRECATED** - If ``True``, then event dispatcher will run - in a separate thread, allowing for the stream to be - used in the background for another application. - Defaults to ``True``. This does **not** mean that no - threads are used at all if ``threaded=False``. - - Regardless of these threading options, these threads will - always exist: - - - The event queue processor - - The send queue processor - - The scheduler - """ - if 'threaded' in kwargs and 'block' in kwargs: - raise ValueError("process() called with both " + \ - "block and threaded arguments") - elif 'block' in kwargs: - threaded = not(kwargs.get('block', False)) - else: - threaded = kwargs.get('threaded', True) - - for t in range(0, HANDLER_THREADS): - log.debug("Starting HANDLER THREAD") - self._start_thread('event_thread_%s' % t, self._event_runner) - - self._start_thread('send_thread', self._send_thread) - self._start_thread('scheduler_thread', self._scheduler_thread) - - if threaded: - # Run the XML stream in the background for another application. - self._start_thread('read_thread', self._process, track=False) - else: - self._process() - - def _process(self): - """Start processing the XML streams. - - Processing will continue after any recoverable errors - if reconnections are allowed. - """ - - # The body of this loop will only execute once per connection. - # Additional passes will be made only if an error occurs and - # reconnecting is permitted. - while True: - shutdown = False - try: - # The call to self.__read_xml will block and prevent - # the body of the loop from running until a disconnect - # occurs. After any reconnection, the stream header will - # be resent and processing will resume. - while not self.stop.is_set(): - # Only process the stream while connected to the server - if not self.state.ensure('connected', wait=0.1): - break - # Ensure the stream header is sent for any - # new connections. - if not self.session_started_event.is_set(): - self.send_raw(self.stream_header, now=True) - if not self.__read_xml(): - # If the server terminated the stream, end processing - break - except KeyboardInterrupt: - log.debug("Keyboard Escape Detected in _process") - self.event('killed', direct=True) - shutdown = True - except SystemExit: - log.debug("SystemExit in _process") - shutdown = True - except (SyntaxError, ExpatError) as e: - log.error("Error reading from XML stream.") - self.exception(e) - except (Socket.error, ssl.SSLError) as serr: - self.event('socket_error', serr, direct=True) - log.error('Socket Error #%s: %s', serr.errno, serr.strerror) - except ValueError as e: - msg = e.message if hasattr(e, 'message') else e.args[0] - - if 'I/O operation on closed file' in msg: - log.error('Can not read from closed socket.') - else: - self.exception(e) - except Exception as e: - if not self.stop.is_set(): - log.error('Connection error.') - self.exception(e) - - if not shutdown and not self.stop.is_set() \ - and self.auto_reconnect: - self.reconnect() - else: - self.disconnect() - break - - def __read_xml(self): - """Parse the incoming XML stream - - Stream events are raised for each received stanza. - """ - depth = 0 - root = None - for event, xml in ET.iterparse(self.filesocket, (b'end', b'start')): - if event == b'start': - if depth == 0: - # We have received the start of the root element. - root = xml - log.debug('RECV: %s', tostring(root, xmlns=self.default_ns, - stream=self, - top_level=True, - open_only=True)) - # Perform any stream initialization actions, such - # as handshakes. - self.stream_end_event.clear() - self.start_stream_handler(root) - - # We have a successful stream connection, so reset - # exponential backoff for new reconnect attempts. - self.reconnect_delay = 1.0 - depth += 1 - if event == b'end': - depth -= 1 - if depth == 0: - # The stream's root element has closed, - # terminating the stream. - log.debug("End of stream recieved") - self.stream_end_event.set() - return False - elif depth == 1: - # We only raise events for stanzas that are direct - # children of the root element. - try: - self.__spawn_event(xml) - except RestartStream: - return True - if root is not None: - # Keep the root element empty of children to - # save on memory use. - root.clear() - log.debug("Ending read XML loop") - def _build_stanza(self, xml, default_ns=None): """Create a stanza object from a given XML object. @@ -1599,7 +1130,6 @@ class XMLStream(object): # Convert the raw XML object into a stanza object. If no registered # stanza type applies, a generic StanzaBase stanza will be used. stanza = self._build_stanza(xml) - for filter in self.__filters['in']: if stanza is not None: stanza = filter(stanza) @@ -1619,7 +1149,7 @@ class XMLStream(object): else: stanza_copy = stanza handler.prerun(stanza_copy) - self.event_queue.put(('stanza', handler, stanza_copy)) + self.run_event(('stanza', handler, stanza_copy)) try: if handler.check_delete(): self.__handlers.remove(handler) @@ -1651,71 +1181,45 @@ class XMLStream(object): else: self.exception(e) - def _event_runner(self): - """Process the event queue and execute handlers. - - The number of event runner threads is controlled by HANDLER_THREADS. - - Stream event handlers will all execute in this thread. Custom event - handlers may be spawned in individual threads. - """ - log.debug("Loading event runner") - try: - while not self.stop.is_set(): - event = self.event_queue.get() - if event is None: - continue - - etype, handler = event[0:2] - args = event[2:] - orig = copy.copy(args[0]) + def run_event(self, event): + etype, handler = event[0:2] + args = event[2:] + orig = copy.copy(args[0]) - if etype == 'stanza': - try: - handler.run(args[0]) - except Exception as e: - error_msg = 'Error processing stream handler: %s' - log.exception(error_msg, handler.name) - orig.exception(e) - elif etype == 'schedule': - name = args[2] - try: - log.debug('Scheduled event: %s: %s', name, args[0]) - handler(*args[0], **args[1]) - except Exception as e: - log.exception('Error processing scheduled task') - self.exception(e) - elif etype == 'event': - func, threaded, disposable = handler - try: - if threaded: - x = threading.Thread( - name="Event_%s" % str(func), - target=self._threaded_event_wrapper, - args=(func, args)) - x.daemon = self._use_daemons - x.start() - else: - func(*args) - except Exception as e: - error_msg = 'Error processing event handler: %s' - log.exception(error_msg, str(func)) - if hasattr(orig, 'exception'): - orig.exception(e) - else: - self.exception(e) - elif etype == 'quit': - log.debug("Quitting event runner thread") - break - except KeyboardInterrupt: - log.debug("Keyboard Escape Detected in _event_runner") - self.event('killed', direct=True) - self.disconnect() - except SystemExit: - self.disconnect() - self.event_queue.put(('quit', None, None)) - - self._end_thread('event runner') + if etype == 'stanza': + try: + handler.run(args[0]) + except Exception as e: + error_msg = 'Error processing stream handler: %s' + log.exception(error_msg, handler.name) + orig.exception(e) + elif etype == 'schedule': + name = args[2] + try: + log.debug('Scheduled event: %s: %s', name, args[0]) + handler(*args[0], **args[1]) + except Exception as e: + log.exception('Error processing scheduled task') + self.exception(e) + elif etype == 'event': + func, threaded, disposable = handler + try: + if threaded: + x = threading.Thread( + name="Event_%s" % str(func), + target=self._threaded_event_wrapper, + args=(func, args)) + x.daemon = self._use_daemons + x.start() + else: + func(*args) + except Exception as e: + error_msg = 'Error processing event handler: %s' + log.exception(error_msg, str(func)) + if hasattr(orig, 'exception'): + orig.exception(e) + else: + self.exception(e) def _send_thread(self): """Extract stanzas from the send queue and send them on the stream.""" @@ -1780,10 +1284,6 @@ class XMLStream(object): self._end_thread('send') - def _scheduler_thread(self): - self.scheduler.process(threaded=False) - self._end_thread('scheduler') - def exception(self, exception): """Process an unknown exception. @@ -1792,17 +1292,3 @@ class XMLStream(object): :param exception: An unhandled exception object. """ pass - - -# To comply with PEP8, method names now use underscores. -# Deprecated method names are re-mapped for backwards compatibility. -XMLStream.startTLS = XMLStream.start_tls -XMLStream.registerStanza = XMLStream.register_stanza -XMLStream.removeStanza = XMLStream.remove_stanza -XMLStream.registerHandler = XMLStream.register_handler -XMLStream.removeHandler = XMLStream.remove_handler -XMLStream.setSocket = XMLStream.set_socket -XMLStream.sendRaw = XMLStream.send_raw -XMLStream.getId = XMLStream.get_id -XMLStream.getNewId = XMLStream.new_id -XMLStream.sendXML = XMLStream.send_xml -- cgit v1.2.3