summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream/xmlstream.py')
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py295
1 files changed, 190 insertions, 105 deletions
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index 49f33933..f9ec4947 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -26,14 +26,12 @@ import time
import random
import weakref
import uuid
-try:
- import queue
-except ImportError:
- import Queue as queue
+import errno
from xml.parsers.expat import ExpatError
import sleekxmpp
+from sleekxmpp.util import Queue, QueueEmpty, safedict
from sleekxmpp.thirdparty.statemachine import StateMachine
from sleekxmpp.xmlstream import Scheduler, tostring, cert
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
@@ -52,7 +50,7 @@ RESPONSE_TIMEOUT = 30
#: The time in seconds to wait for events from the event queue, and also the
#: time between checks for the process stop signal.
-WAIT_TIMEOUT = 0.1
+WAIT_TIMEOUT = 1.0
#: The number of threads to use to handle XML stream events. This is not the
#: same as the number of custom event handling threads.
@@ -61,9 +59,6 @@ WAIT_TIMEOUT = 0.1
#: a GIL increasing this value can provide better performance.
HANDLER_THREADS = 1
-#: Flag indicating if the SSL library is available for use.
-SSL_SUPPORT = True
-
#: The time in seconds to delay between attempts to resend data
#: after an SSL error.
SSL_RETRY_DELAY = 0.5
@@ -120,9 +115,6 @@ class XMLStream(object):
"""
def __init__(self, socket=None, host='', port=0):
- #: Flag indicating if the SSL library is available for use.
- self.ssl_support = SSL_SUPPORT
-
#: Most XMPP servers support TLSv1, but OpenFire in particular
#: does not work well with it. For OpenFire, set
#: :attr:`ssl_version` to use ``SSLv23``::
@@ -131,6 +123,11 @@ class XMLStream(object):
#: xmpp.ssl_version = ssl.PROTOCOL_SSLv23
self.ssl_version = ssl.PROTOCOL_TLSv1
+ #: The list of accepted ciphers, in OpenSSL Format.
+ #: It might be useful to override it for improved security
+ #: over the python defaults.
+ self.ciphers = None
+
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
@@ -141,6 +138,17 @@ class XMLStream(object):
#: be consulted, even if they are not in the provided file.
self.ca_certs = None
+ #: Path to a file containing a client certificate to use for
+ #: authenticating via SASL EXTERNAL. If set, there must also
+ #: be a corresponding `:attr:keyfile` value.
+ self.certfile = None
+
+ #: Path to a file containing the private key for the selected
+ #: client certificate to use for authenticating via SASL EXTERNAL.
+ self.keyfile = None
+
+ self._der_cert = None
+
#: The time in seconds to wait for events from the event queue,
#: and also the time between checks for the process stop signal.
self.wait_timeout = WAIT_TIMEOUT
@@ -184,6 +192,7 @@ class XMLStream(object):
#: The expected name of the server, for validation.
self._expected_server_name = ''
+ self._service_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
@@ -215,6 +224,15 @@ class XMLStream(object):
#: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
+ #: If set to ``True``, allow using the ``dnspython`` DNS library
+ #: if available. If set to ``False``, the builtin DNS resolver
+ #: will be used, even if ``dnspython`` is installed.
+ self.use_dnspython = True
+
+ #: Use CDATA for escaping instead of XML entities. Defaults
+ #: to ``False``.
+ self.use_cdata = False
+
#: An optional dictionary of proxy settings. It may provide:
#: :host: The host offering proxy services.
#: :port: The port for the proxy service.
@@ -270,10 +288,10 @@ class XMLStream(object):
self.end_session_on_disconnect = True
#: A queue of stream, custom, and scheduled events to be processed.
- self.event_queue = queue.Queue()
+ self.event_queue = Queue()
#: A queue of string data to be sent over the stream.
- self.send_queue = queue.Queue()
+ self.send_queue = Queue(maxsize=256)
self.send_queue_lock = threading.Lock()
self.send_lock = threading.RLock()
@@ -322,7 +340,7 @@ class XMLStream(object):
#: ``_xmpp-client._tcp`` service.
self.dns_service = None
- self.add_event_handler('connected', self._handle_connected)
+ self.add_event_handler('connected', self._session_timeout_check)
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('session_start', self._cert_expiration)
@@ -407,6 +425,8 @@ class XMLStream(object):
:param reattempt: Flag indicating if the socket should reconnect
after disconnections.
"""
+ self.stop.clear()
+
if host and port:
self.address = (host, int(port))
try:
@@ -439,11 +459,12 @@ class XMLStream(object):
def _connect(self, reattempt=True):
self.scheduler.remove('Session timeout check')
- self.stop.clear()
- if self.reconnect_delay is None or not reattempt:
+ if self.reconnect_delay is None:
delay = 1.0
- else:
+ self.reconnect_delay = delay
+
+ if reattempt:
delay = min(self.reconnect_delay * 2, self.reconnect_max_delay)
delay = random.normalvariate(delay, delay * 0.1)
log.debug('Waiting %s seconds before connecting.', delay)
@@ -453,16 +474,18 @@ class XMLStream(object):
time.sleep(0.1)
elapsed += 0.1
except KeyboardInterrupt:
- self.stop.set()
+ self.set_stop()
return False
except SystemExit:
- self.stop.set()
+ self.set_stop()
return False
if self.default_domain:
try:
- self.address = self.pick_dns_answer(self.default_domain,
- self.address[1])
+ host, address, port = self.pick_dns_answer(self.default_domain,
+ self.address[1])
+ self.address = (address, port)
+ self._service_name = host
except StopIteration:
log.debug("No remaining DNS records to try.")
self.dns_answers = None
@@ -490,17 +513,26 @@ class XMLStream(object):
self.reconnect_delay = delay
return False
- if self.use_ssl and self.ssl_support:
+ if self.use_ssl:
log.debug("Socket Wrapped for SSL")
if self.ca_certs is None:
cert_policy = ssl.CERT_NONE
else:
cert_policy = ssl.CERT_REQUIRED
- ssl_socket = ssl.wrap_socket(self.socket,
- ca_certs=self.ca_certs,
- cert_reqs=cert_policy,
- do_handshake_on_connect=False)
+ ssl_args = safedict({
+ 'certfile': self.certfile,
+ 'keyfile': self.keyfile,
+ 'ca_certs': self.ca_certs,
+ 'cert_reqs': cert_policy,
+ 'do_handshake_on_connect': False,
+ "ssl_version": self.ssl_version
+ })
+
+ if sys.version_info >= (2, 7):
+ ssl_args['ciphers'] = self.ciphers
+
+ ssl_socket = ssl.wrap_socket(self.socket, **ssl_args)
if hasattr(self.socket, 'socket'):
# We are using a testing socket, so preserve the top
@@ -517,7 +549,7 @@ class XMLStream(object):
log.debug("Connecting to %s:%s", domain, self.address[1])
self.socket.connect(self.address)
- if self.use_ssl and self.ssl_support:
+ if self.use_ssl:
try:
self.socket.do_handshake()
except (Socket.error, ssl.SSLError):
@@ -538,7 +570,7 @@ class XMLStream(object):
cert.verify(self._expected_server_name, self._der_cert)
except cert.CertificateError as err:
if not self.event_handled('ssl_invalid_cert'):
- log.error(err.message)
+ log.error(err)
self.disconnect(send_close=False)
else:
self.event('ssl_invalid_cert',
@@ -547,8 +579,7 @@ class XMLStream(object):
self.set_socket(self.socket, ignore=True)
#this event is where you should set your application state
- self.event("connected", direct=True)
- self.reconnect_delay = 1.0
+ self.event('connected', direct=True)
return True
except (Socket.error, ssl.SSLError) as serr:
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
@@ -588,7 +619,7 @@ class XMLStream(object):
headers = '\r\n'.join(headers) + '\r\n\r\n'
try:
- log.debug("Connecting to proxy: %s:%s", address)
+ log.debug("Connecting to proxy: %s:%s", *address)
self.socket.connect(address)
self.send_raw(headers, now=True)
resp = ''
@@ -599,6 +630,7 @@ class XMLStream(object):
lines = resp.split('\r\n')
if '200' not in lines[0]:
self.event('proxy_error', resp)
+ self.event('connection_failed', direct=True)
log.error('Proxy Error: %s', lines[0])
return False
@@ -612,7 +644,7 @@ class XMLStream(object):
serr.errno, serr.strerror)
return False
- def _handle_connected(self, event=None):
+ def _session_timeout_check(self, event=None):
"""
Add check to ensure that a session is established within
a reasonable amount of time.
@@ -661,6 +693,9 @@ class XMLStream(object):
args=(reconnect, wait, send_close))
def _disconnect(self, reconnect=False, wait=None, send_close=True):
+ if not reconnect:
+ self.auto_reconnect = False
+
if self.end_session_on_disconnect or send_close:
self.event('session_end', direct=True)
@@ -684,7 +719,6 @@ class XMLStream(object):
# closed in the other direction. If we didn't
# send a stream footer we don't need to wait
# since the server won't know to respond.
- self.auto_reconnect = reconnect
if send_close:
log.info('Waiting for %s from server', self.stream_footer)
self.stream_end_event.wait(4)
@@ -692,7 +726,7 @@ class XMLStream(object):
self.stream_end_event.set()
if not self.auto_reconnect:
- self.stop.set()
+ self.set_stop()
if self._disconnect_wait_for_threads:
self._wait_for_threads()
@@ -704,9 +738,23 @@ class XMLStream(object):
self.event('socket_error', serr, direct=True)
finally:
#clear your application state
- self.event("disconnected", direct=True)
+ self.event('disconnected', direct=True)
return True
+ def abort(self):
+ self.session_started_event.clear()
+ self.set_stop()
+ if self._disconnect_wait_for_threads:
+ self._wait_for_threads()
+ try:
+ self.socket.shutdown(Socket.SHUT_RDWR)
+ self.socket.close()
+ self.filesocket.close()
+ except Socket.error:
+ pass
+ self.state.transition_any(['connected', 'disconnected'], 'disconnected', func=lambda: True)
+ self.event("killed", direct=True)
+
def reconnect(self, reattempt=True, wait=False, send_close=True):
"""Reset the stream's state and reconnect to the server."""
log.debug("reconnecting...")
@@ -789,56 +837,62 @@ class XMLStream(object):
If the handshake is successful, the XML stream will need
to be restarted.
"""
- if self.ssl_support:
- log.info("Negotiating TLS")
- log.info("Using SSL version: %s", str(self.ssl_version))
- if self.ca_certs is None:
- cert_policy = ssl.CERT_NONE
- else:
- cert_policy = ssl.CERT_REQUIRED
-
- ssl_socket = ssl.wrap_socket(self.socket,
- ssl_version=self.ssl_version,
- do_handshake_on_connect=False,
- ca_certs=self.ca_certs,
- cert_reqs=cert_policy)
+ log.info("Negotiating TLS")
+ ssl_versions = {3: 'TLS 1.0', 1: 'SSL 3', 2: 'SSL 2/3'}
+ log.info("Using SSL version: %s", ssl_versions[self.ssl_version])
+ if self.ca_certs is None:
+ cert_policy = ssl.CERT_NONE
+ else:
+ cert_policy = ssl.CERT_REQUIRED
+
+ ssl_args = safedict({
+ 'certfile': self.certfile,
+ 'keyfile': self.keyfile,
+ 'ca_certs': self.ca_certs,
+ 'cert_reqs': cert_policy,
+ 'do_handshake_on_connect': False,
+ "ssl_version": self.ssl_version
+ })
+
+ if sys.version_info >= (2, 7):
+ ssl_args['ciphers'] = self.ciphers
+
+ ssl_socket = ssl.wrap_socket(self.socket, **ssl_args)
+
+ if hasattr(self.socket, 'socket'):
+ # We are using a testing socket, so preserve the top
+ # layer of wrapping.
+ self.socket.socket = ssl_socket
+ else:
+ self.socket = ssl_socket
- if hasattr(self.socket, 'socket'):
- # We are using a testing socket, so preserve the top
- # layer of wrapping.
- self.socket.socket = ssl_socket
+ try:
+ self.socket.do_handshake()
+ except (Socket.error, ssl.SSLError):
+ log.error('CERT: Invalid certificate trust chain.')
+ if not self.event_handled('ssl_invalid_chain'):
+ self.disconnect(self.auto_reconnect, send_close=False)
else:
- self.socket = ssl_socket
-
- try:
- self.socket.do_handshake()
- except (Socket.error, ssl.SSLError):
- log.error('CERT: Invalid certificate trust chain.')
- if not self.event_handled('ssl_invalid_chain'):
- self.disconnect(self.auto_reconnect, send_close=False)
- else:
- self.event('ssl_invalid_chain', direct=True)
- return False
+ self._der_cert = self.socket.getpeercert(binary_form=True)
+ self.event('ssl_invalid_chain', direct=True)
+ return False
- self._der_cert = self.socket.getpeercert(binary_form=True)
- pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
- log.debug('CERT: %s', pem_cert)
- self.event('ssl_cert', pem_cert, direct=True)
+ self._der_cert = self.socket.getpeercert(binary_form=True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
+ log.debug('CERT: %s', pem_cert)
+ self.event('ssl_cert', pem_cert, direct=True)
- try:
- cert.verify(self._expected_server_name, self._der_cert)
- except cert.CertificateError as err:
- if not self.event_handled('ssl_invalid_cert'):
- log.error(err.message)
- self.disconnect(self.auto_reconnect, send_close=False)
- else:
- self.event('ssl_invalid_cert', pem_cert, direct=True)
+ try:
+ cert.verify(self._expected_server_name, self._der_cert)
+ except cert.CertificateError as err:
+ if not self.event_handled('ssl_invalid_cert'):
+ log.error(err)
+ self.disconnect(self.auto_reconnect, send_close=False)
+ else:
+ self.event('ssl_invalid_cert', pem_cert, direct=True)
- self.set_socket(self.socket)
- return True
- else:
- log.warning("Tried to enable TLS, but ssl module not found.")
- return False
+ self.set_socket(self.socket)
+ return True
def _cert_expiration(self, event):
"""Schedule an event for when the TLS certificate expires."""
@@ -866,9 +920,15 @@ class XMLStream(object):
log.warn('CERT: Certificate has expired.')
restart()
+ try:
+ total_seconds = cert_ttl.total_seconds()
+ except AttributeError:
+ # for Python < 2.7
+ total_seconds = (cert_ttl.microseconds + (cert_ttl.seconds + cert_ttl.days * 24 * 3600) * 10**6) / 10**6
+
log.info('CERT: Time until certificate expiration: %s' % cert_ttl)
self.schedule('Certificate Expiration',
- cert_ttl.seconds,
+ total_seconds,
restart)
def _start_keepalive(self, event):
@@ -882,12 +942,13 @@ class XMLStream(object):
self.whitespace_keepalive_interval = 300
"""
- self.schedule('Whitespace Keepalive',
- self.whitespace_keepalive_interval,
- self.send_raw,
- args=(' ',),
- kwargs={'now': True},
- repeat=True)
+ if self.whitespace_keepalive:
+ self.schedule('Whitespace Keepalive',
+ self.whitespace_keepalive_interval,
+ self.send_raw,
+ args=(' ',),
+ kwargs={'now': True},
+ repeat=True)
def _remove_schedules(self, event):
"""Remove whitespace keepalive and certificate expiration schedules."""
@@ -983,9 +1044,13 @@ class XMLStream(object):
# and handler classes here.
if name is None:
- name = 'add_handler_%s' % self.getNewId()
- self.registerHandler(XMLCallback(name, MatchXMLMask(mask), pointer,
- once=disposable, instream=instream))
+ name = 'add_handler_%s' % self.new_id()
+ self.register_handler(
+ XMLCallback(name,
+ MatchXMLMask(mask, self.default_ns),
+ pointer,
+ once=disposable,
+ instream=instream))
def register_handler(self, handler, before=None, after=None):
"""Add a stream event handler that will be executed when a matching
@@ -1026,7 +1091,8 @@ class XMLStream(object):
return resolve(domain, port, service=self.dns_service,
resolver=resolver,
- use_ipv6=self.use_ipv6)
+ use_ipv6=self.use_ipv6,
+ use_dnspython=self.use_dnspython)
def pick_dns_answer(self, domain, port=None):
"""Pick a server and port from DNS answers.
@@ -1087,7 +1153,7 @@ class XMLStream(object):
"""
return len(self.__event_handlers.get(name, []))
- def event(self, name, data={}, direct=False):
+ def event(self, name, data=None, direct=False):
"""Manually trigger a custom event.
:param name: The name of the event to trigger.
@@ -1098,6 +1164,11 @@ class XMLStream(object):
event queue. All event handlers will run in the
same thread.
"""
+ if not data:
+ data = {}
+
+ log.debug("Event triggered: " + name)
+
handlers = self.__event_handlers.get(name, [])
for handler in handlers:
#TODO: Data should not be copied, but should be read only,
@@ -1202,7 +1273,9 @@ class XMLStream(object):
data = filter(data)
if data is None:
return
- str_data = str(data)
+ str_data = tostring(data.xml, xmlns=self.default_ns,
+ stream=self,
+ top_level=True)
self.send_raw(str_data, now)
else:
self.send_raw(data, now)
@@ -1267,6 +1340,9 @@ class XMLStream(object):
if not self.stop.is_set():
time.sleep(self.ssl_retry_delay)
tries += 1
+ except Socket.error as serr:
+ if serr.errno != errno.EINTR:
+ raise
if count > 1:
log.debug('SENT: %d chunks', count)
except (Socket.error, ssl.SSLError) as serr:
@@ -1281,12 +1357,12 @@ class XMLStream(object):
return True
def _start_thread(self, name, target, track=True):
- self.__active_threads.add(name)
self.__thread[name] = threading.Thread(name=name, target=target)
self.__thread[name].daemon = self._use_daemons
self.__thread[name].start()
if track:
+ self.__active_threads.add(name)
with self.__thread_cond:
self.__thread_count += 1
@@ -1315,6 +1391,13 @@ class XMLStream(object):
if self.__thread_count == 0:
self.__thread_cond.notify()
+ def set_stop(self):
+ self.stop.set()
+
+ # Unlock queues
+ self.event_queue.put(None)
+ self.send_queue.put(None)
+
def _wait_for_threads(self):
with self.__thread_cond:
if self.__thread_count != 0:
@@ -1458,6 +1541,10 @@ class XMLStream(object):
# as handshakes.
self.stream_end_event.clear()
self.start_stream_handler(root)
+
+ # We have a successful stream connection, so reset
+ # exponential backoff for new reconnect attempts.
+ self.reconnect_delay = 1.0
depth += 1
if event == b'end':
depth -= 1
@@ -1583,11 +1670,7 @@ class XMLStream(object):
log.debug("Loading event runner")
try:
while not self.stop.is_set():
- try:
- wait = self.wait_timeout
- event = self.event_queue.get(True, timeout=wait)
- except queue.Empty:
- event = None
+ event = self.event_queue.get()
if event is None:
continue
@@ -1603,10 +1686,10 @@ class XMLStream(object):
log.exception(error_msg, handler.name)
orig.exception(e)
elif etype == 'schedule':
- name = args[1]
+ name = args[2]
try:
log.debug('Scheduled event: %s: %s', name, args[0])
- handler(*args[0])
+ handler(*args[0], **args[1])
except Exception as e:
log.exception('Error processing scheduled task')
self.exception(e)
@@ -1648,14 +1731,13 @@ class XMLStream(object):
while not self.stop.is_set():
while not self.stop.is_set() and \
not self.session_started_event.is_set():
- self.session_started_event.wait(timeout=0.1)
+ self.session_started_event.wait(timeout=0.1) # Wait for session start
if self.__failed_send_stanza is not None:
data = self.__failed_send_stanza
self.__failed_send_stanza = None
else:
- try:
- data = self.send_queue.get(True, 1)
- except queue.Empty:
+ data = self.send_queue.get() # Wait for data to send
+ if data is None:
continue
log.debug("SEND: %s", data)
enc_data = data.encode('utf-8')
@@ -1682,6 +1764,9 @@ class XMLStream(object):
if not self.stop.is_set():
time.sleep(self.ssl_retry_delay)
tries += 1
+ except Socket.error as serr:
+ if serr.errno != errno.EINTR:
+ raise
if count > 1:
log.debug('SENT: %d chunks', count)
self.send_queue.task_done()