diff options
Diffstat (limited to 'sleekxmpp/xmlstream/xmlstream.py')
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 295 |
1 files changed, 190 insertions, 105 deletions
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 49f33933..f9ec4947 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -26,14 +26,12 @@ import time import random import weakref import uuid -try: - import queue -except ImportError: - import Queue as queue +import errno from xml.parsers.expat import ExpatError import sleekxmpp +from sleekxmpp.util import Queue, QueueEmpty, safedict from sleekxmpp.thirdparty.statemachine import StateMachine from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase @@ -52,7 +50,7 @@ RESPONSE_TIMEOUT = 30 #: The time in seconds to wait for events from the event queue, and also the #: time between checks for the process stop signal. -WAIT_TIMEOUT = 0.1 +WAIT_TIMEOUT = 1.0 #: The number of threads to use to handle XML stream events. This is not the #: same as the number of custom event handling threads. @@ -61,9 +59,6 @@ WAIT_TIMEOUT = 0.1 #: a GIL increasing this value can provide better performance. HANDLER_THREADS = 1 -#: Flag indicating if the SSL library is available for use. -SSL_SUPPORT = True - #: The time in seconds to delay between attempts to resend data #: after an SSL error. SSL_RETRY_DELAY = 0.5 @@ -120,9 +115,6 @@ class XMLStream(object): """ def __init__(self, socket=None, host='', port=0): - #: Flag indicating if the SSL library is available for use. - self.ssl_support = SSL_SUPPORT - #: Most XMPP servers support TLSv1, but OpenFire in particular #: does not work well with it. For OpenFire, set #: :attr:`ssl_version` to use ``SSLv23``:: @@ -131,6 +123,11 @@ class XMLStream(object): #: xmpp.ssl_version = ssl.PROTOCOL_SSLv23 self.ssl_version = ssl.PROTOCOL_TLSv1 + #: The list of accepted ciphers, in OpenSSL Format. + #: It might be useful to override it for improved security + #: over the python defaults. + self.ciphers = None + #: Path to a file containing certificates for verifying the #: server SSL certificate. A non-``None`` value will trigger #: certificate checking. @@ -141,6 +138,17 @@ class XMLStream(object): #: be consulted, even if they are not in the provided file. self.ca_certs = None + #: Path to a file containing a client certificate to use for + #: authenticating via SASL EXTERNAL. If set, there must also + #: be a corresponding `:attr:keyfile` value. + self.certfile = None + + #: Path to a file containing the private key for the selected + #: client certificate to use for authenticating via SASL EXTERNAL. + self.keyfile = None + + self._der_cert = None + #: The time in seconds to wait for events from the event queue, #: and also the time between checks for the process stop signal. self.wait_timeout = WAIT_TIMEOUT @@ -184,6 +192,7 @@ class XMLStream(object): #: The expected name of the server, for validation. self._expected_server_name = '' + self._service_name = '' #: The desired, or actual, address of the connected server. self.address = (host, int(port)) @@ -215,6 +224,15 @@ class XMLStream(object): #: If set to ``True``, attempt to use IPv6. self.use_ipv6 = True + #: If set to ``True``, allow using the ``dnspython`` DNS library + #: if available. If set to ``False``, the builtin DNS resolver + #: will be used, even if ``dnspython`` is installed. + self.use_dnspython = True + + #: Use CDATA for escaping instead of XML entities. Defaults + #: to ``False``. + self.use_cdata = False + #: An optional dictionary of proxy settings. It may provide: #: :host: The host offering proxy services. #: :port: The port for the proxy service. @@ -270,10 +288,10 @@ class XMLStream(object): self.end_session_on_disconnect = True #: A queue of stream, custom, and scheduled events to be processed. - self.event_queue = queue.Queue() + self.event_queue = Queue() #: A queue of string data to be sent over the stream. - self.send_queue = queue.Queue() + self.send_queue = Queue(maxsize=256) self.send_queue_lock = threading.Lock() self.send_lock = threading.RLock() @@ -322,7 +340,7 @@ class XMLStream(object): #: ``_xmpp-client._tcp`` service. self.dns_service = None - self.add_event_handler('connected', self._handle_connected) + self.add_event_handler('connected', self._session_timeout_check) self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('session_start', self._cert_expiration) @@ -407,6 +425,8 @@ class XMLStream(object): :param reattempt: Flag indicating if the socket should reconnect after disconnections. """ + self.stop.clear() + if host and port: self.address = (host, int(port)) try: @@ -439,11 +459,12 @@ class XMLStream(object): def _connect(self, reattempt=True): self.scheduler.remove('Session timeout check') - self.stop.clear() - if self.reconnect_delay is None or not reattempt: + if self.reconnect_delay is None: delay = 1.0 - else: + self.reconnect_delay = delay + + if reattempt: delay = min(self.reconnect_delay * 2, self.reconnect_max_delay) delay = random.normalvariate(delay, delay * 0.1) log.debug('Waiting %s seconds before connecting.', delay) @@ -453,16 +474,18 @@ class XMLStream(object): time.sleep(0.1) elapsed += 0.1 except KeyboardInterrupt: - self.stop.set() + self.set_stop() return False except SystemExit: - self.stop.set() + self.set_stop() return False if self.default_domain: try: - self.address = self.pick_dns_answer(self.default_domain, - self.address[1]) + host, address, port = self.pick_dns_answer(self.default_domain, + self.address[1]) + self.address = (address, port) + self._service_name = host except StopIteration: log.debug("No remaining DNS records to try.") self.dns_answers = None @@ -490,17 +513,26 @@ class XMLStream(object): self.reconnect_delay = delay return False - if self.use_ssl and self.ssl_support: + if self.use_ssl: log.debug("Socket Wrapped for SSL") if self.ca_certs is None: cert_policy = ssl.CERT_NONE else: cert_policy = ssl.CERT_REQUIRED - ssl_socket = ssl.wrap_socket(self.socket, - ca_certs=self.ca_certs, - cert_reqs=cert_policy, - do_handshake_on_connect=False) + ssl_args = safedict({ + 'certfile': self.certfile, + 'keyfile': self.keyfile, + 'ca_certs': self.ca_certs, + 'cert_reqs': cert_policy, + 'do_handshake_on_connect': False, + "ssl_version": self.ssl_version + }) + + if sys.version_info >= (2, 7): + ssl_args['ciphers'] = self.ciphers + + ssl_socket = ssl.wrap_socket(self.socket, **ssl_args) if hasattr(self.socket, 'socket'): # We are using a testing socket, so preserve the top @@ -517,7 +549,7 @@ class XMLStream(object): log.debug("Connecting to %s:%s", domain, self.address[1]) self.socket.connect(self.address) - if self.use_ssl and self.ssl_support: + if self.use_ssl: try: self.socket.do_handshake() except (Socket.error, ssl.SSLError): @@ -538,7 +570,7 @@ class XMLStream(object): cert.verify(self._expected_server_name, self._der_cert) except cert.CertificateError as err: if not self.event_handled('ssl_invalid_cert'): - log.error(err.message) + log.error(err) self.disconnect(send_close=False) else: self.event('ssl_invalid_cert', @@ -547,8 +579,7 @@ class XMLStream(object): self.set_socket(self.socket, ignore=True) #this event is where you should set your application state - self.event("connected", direct=True) - self.reconnect_delay = 1.0 + self.event('connected', direct=True) return True except (Socket.error, ssl.SSLError) as serr: error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" @@ -588,7 +619,7 @@ class XMLStream(object): headers = '\r\n'.join(headers) + '\r\n\r\n' try: - log.debug("Connecting to proxy: %s:%s", address) + log.debug("Connecting to proxy: %s:%s", *address) self.socket.connect(address) self.send_raw(headers, now=True) resp = '' @@ -599,6 +630,7 @@ class XMLStream(object): lines = resp.split('\r\n') if '200' not in lines[0]: self.event('proxy_error', resp) + self.event('connection_failed', direct=True) log.error('Proxy Error: %s', lines[0]) return False @@ -612,7 +644,7 @@ class XMLStream(object): serr.errno, serr.strerror) return False - def _handle_connected(self, event=None): + def _session_timeout_check(self, event=None): """ Add check to ensure that a session is established within a reasonable amount of time. @@ -661,6 +693,9 @@ class XMLStream(object): args=(reconnect, wait, send_close)) def _disconnect(self, reconnect=False, wait=None, send_close=True): + if not reconnect: + self.auto_reconnect = False + if self.end_session_on_disconnect or send_close: self.event('session_end', direct=True) @@ -684,7 +719,6 @@ class XMLStream(object): # closed in the other direction. If we didn't # send a stream footer we don't need to wait # since the server won't know to respond. - self.auto_reconnect = reconnect if send_close: log.info('Waiting for %s from server', self.stream_footer) self.stream_end_event.wait(4) @@ -692,7 +726,7 @@ class XMLStream(object): self.stream_end_event.set() if not self.auto_reconnect: - self.stop.set() + self.set_stop() if self._disconnect_wait_for_threads: self._wait_for_threads() @@ -704,9 +738,23 @@ class XMLStream(object): self.event('socket_error', serr, direct=True) finally: #clear your application state - self.event("disconnected", direct=True) + self.event('disconnected', direct=True) return True + def abort(self): + self.session_started_event.clear() + self.set_stop() + if self._disconnect_wait_for_threads: + self._wait_for_threads() + try: + self.socket.shutdown(Socket.SHUT_RDWR) + self.socket.close() + self.filesocket.close() + except Socket.error: + pass + self.state.transition_any(['connected', 'disconnected'], 'disconnected', func=lambda: True) + self.event("killed", direct=True) + def reconnect(self, reattempt=True, wait=False, send_close=True): """Reset the stream's state and reconnect to the server.""" log.debug("reconnecting...") @@ -789,56 +837,62 @@ class XMLStream(object): If the handshake is successful, the XML stream will need to be restarted. """ - if self.ssl_support: - log.info("Negotiating TLS") - log.info("Using SSL version: %s", str(self.ssl_version)) - if self.ca_certs is None: - cert_policy = ssl.CERT_NONE - else: - cert_policy = ssl.CERT_REQUIRED - - ssl_socket = ssl.wrap_socket(self.socket, - ssl_version=self.ssl_version, - do_handshake_on_connect=False, - ca_certs=self.ca_certs, - cert_reqs=cert_policy) + log.info("Negotiating TLS") + ssl_versions = {3: 'TLS 1.0', 1: 'SSL 3', 2: 'SSL 2/3'} + log.info("Using SSL version: %s", ssl_versions[self.ssl_version]) + if self.ca_certs is None: + cert_policy = ssl.CERT_NONE + else: + cert_policy = ssl.CERT_REQUIRED + + ssl_args = safedict({ + 'certfile': self.certfile, + 'keyfile': self.keyfile, + 'ca_certs': self.ca_certs, + 'cert_reqs': cert_policy, + 'do_handshake_on_connect': False, + "ssl_version": self.ssl_version + }) + + if sys.version_info >= (2, 7): + ssl_args['ciphers'] = self.ciphers + + ssl_socket = ssl.wrap_socket(self.socket, **ssl_args) + + if hasattr(self.socket, 'socket'): + # We are using a testing socket, so preserve the top + # layer of wrapping. + self.socket.socket = ssl_socket + else: + self.socket = ssl_socket - if hasattr(self.socket, 'socket'): - # We are using a testing socket, so preserve the top - # layer of wrapping. - self.socket.socket = ssl_socket + try: + self.socket.do_handshake() + except (Socket.error, ssl.SSLError): + log.error('CERT: Invalid certificate trust chain.') + if not self.event_handled('ssl_invalid_chain'): + self.disconnect(self.auto_reconnect, send_close=False) else: - self.socket = ssl_socket - - try: - self.socket.do_handshake() - except (Socket.error, ssl.SSLError): - log.error('CERT: Invalid certificate trust chain.') - if not self.event_handled('ssl_invalid_chain'): - self.disconnect(self.auto_reconnect, send_close=False) - else: - self.event('ssl_invalid_chain', direct=True) - return False + self._der_cert = self.socket.getpeercert(binary_form=True) + self.event('ssl_invalid_chain', direct=True) + return False - self._der_cert = self.socket.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) - log.debug('CERT: %s', pem_cert) - self.event('ssl_cert', pem_cert, direct=True) + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + self.event('ssl_cert', pem_cert, direct=True) - try: - cert.verify(self._expected_server_name, self._der_cert) - except cert.CertificateError as err: - if not self.event_handled('ssl_invalid_cert'): - log.error(err.message) - self.disconnect(self.auto_reconnect, send_close=False) - else: - self.event('ssl_invalid_cert', pem_cert, direct=True) + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + if not self.event_handled('ssl_invalid_cert'): + log.error(err) + self.disconnect(self.auto_reconnect, send_close=False) + else: + self.event('ssl_invalid_cert', pem_cert, direct=True) - self.set_socket(self.socket) - return True - else: - log.warning("Tried to enable TLS, but ssl module not found.") - return False + self.set_socket(self.socket) + return True def _cert_expiration(self, event): """Schedule an event for when the TLS certificate expires.""" @@ -866,9 +920,15 @@ class XMLStream(object): log.warn('CERT: Certificate has expired.') restart() + try: + total_seconds = cert_ttl.total_seconds() + except AttributeError: + # for Python < 2.7 + total_seconds = (cert_ttl.microseconds + (cert_ttl.seconds + cert_ttl.days * 24 * 3600) * 10**6) / 10**6 + log.info('CERT: Time until certificate expiration: %s' % cert_ttl) self.schedule('Certificate Expiration', - cert_ttl.seconds, + total_seconds, restart) def _start_keepalive(self, event): @@ -882,12 +942,13 @@ class XMLStream(object): self.whitespace_keepalive_interval = 300 """ - self.schedule('Whitespace Keepalive', - self.whitespace_keepalive_interval, - self.send_raw, - args=(' ',), - kwargs={'now': True}, - repeat=True) + if self.whitespace_keepalive: + self.schedule('Whitespace Keepalive', + self.whitespace_keepalive_interval, + self.send_raw, + args=(' ',), + kwargs={'now': True}, + repeat=True) def _remove_schedules(self, event): """Remove whitespace keepalive and certificate expiration schedules.""" @@ -983,9 +1044,13 @@ class XMLStream(object): # and handler classes here. if name is None: - name = 'add_handler_%s' % self.getNewId() - self.registerHandler(XMLCallback(name, MatchXMLMask(mask), pointer, - once=disposable, instream=instream)) + name = 'add_handler_%s' % self.new_id() + self.register_handler( + XMLCallback(name, + MatchXMLMask(mask, self.default_ns), + pointer, + once=disposable, + instream=instream)) def register_handler(self, handler, before=None, after=None): """Add a stream event handler that will be executed when a matching @@ -1026,7 +1091,8 @@ class XMLStream(object): return resolve(domain, port, service=self.dns_service, resolver=resolver, - use_ipv6=self.use_ipv6) + use_ipv6=self.use_ipv6, + use_dnspython=self.use_dnspython) def pick_dns_answer(self, domain, port=None): """Pick a server and port from DNS answers. @@ -1087,7 +1153,7 @@ class XMLStream(object): """ return len(self.__event_handlers.get(name, [])) - def event(self, name, data={}, direct=False): + def event(self, name, data=None, direct=False): """Manually trigger a custom event. :param name: The name of the event to trigger. @@ -1098,6 +1164,11 @@ class XMLStream(object): event queue. All event handlers will run in the same thread. """ + if not data: + data = {} + + log.debug("Event triggered: " + name) + handlers = self.__event_handlers.get(name, []) for handler in handlers: #TODO: Data should not be copied, but should be read only, @@ -1202,7 +1273,9 @@ class XMLStream(object): data = filter(data) if data is None: return - str_data = str(data) + str_data = tostring(data.xml, xmlns=self.default_ns, + stream=self, + top_level=True) self.send_raw(str_data, now) else: self.send_raw(data, now) @@ -1267,6 +1340,9 @@ class XMLStream(object): if not self.stop.is_set(): time.sleep(self.ssl_retry_delay) tries += 1 + except Socket.error as serr: + if serr.errno != errno.EINTR: + raise if count > 1: log.debug('SENT: %d chunks', count) except (Socket.error, ssl.SSLError) as serr: @@ -1281,12 +1357,12 @@ class XMLStream(object): return True def _start_thread(self, name, target, track=True): - self.__active_threads.add(name) self.__thread[name] = threading.Thread(name=name, target=target) self.__thread[name].daemon = self._use_daemons self.__thread[name].start() if track: + self.__active_threads.add(name) with self.__thread_cond: self.__thread_count += 1 @@ -1315,6 +1391,13 @@ class XMLStream(object): if self.__thread_count == 0: self.__thread_cond.notify() + def set_stop(self): + self.stop.set() + + # Unlock queues + self.event_queue.put(None) + self.send_queue.put(None) + def _wait_for_threads(self): with self.__thread_cond: if self.__thread_count != 0: @@ -1458,6 +1541,10 @@ class XMLStream(object): # as handshakes. self.stream_end_event.clear() self.start_stream_handler(root) + + # We have a successful stream connection, so reset + # exponential backoff for new reconnect attempts. + self.reconnect_delay = 1.0 depth += 1 if event == b'end': depth -= 1 @@ -1583,11 +1670,7 @@ class XMLStream(object): log.debug("Loading event runner") try: while not self.stop.is_set(): - try: - wait = self.wait_timeout - event = self.event_queue.get(True, timeout=wait) - except queue.Empty: - event = None + event = self.event_queue.get() if event is None: continue @@ -1603,10 +1686,10 @@ class XMLStream(object): log.exception(error_msg, handler.name) orig.exception(e) elif etype == 'schedule': - name = args[1] + name = args[2] try: log.debug('Scheduled event: %s: %s', name, args[0]) - handler(*args[0]) + handler(*args[0], **args[1]) except Exception as e: log.exception('Error processing scheduled task') self.exception(e) @@ -1648,14 +1731,13 @@ class XMLStream(object): while not self.stop.is_set(): while not self.stop.is_set() and \ not self.session_started_event.is_set(): - self.session_started_event.wait(timeout=0.1) + self.session_started_event.wait(timeout=0.1) # Wait for session start if self.__failed_send_stanza is not None: data = self.__failed_send_stanza self.__failed_send_stanza = None else: - try: - data = self.send_queue.get(True, 1) - except queue.Empty: + data = self.send_queue.get() # Wait for data to send + if data is None: continue log.debug("SEND: %s", data) enc_data = data.encode('utf-8') @@ -1682,6 +1764,9 @@ class XMLStream(object): if not self.stop.is_set(): time.sleep(self.ssl_retry_delay) tries += 1 + except Socket.error as serr: + if serr.errno != errno.EINTR: + raise if count > 1: log.debug('SENT: %d chunks', count) self.send_queue.task_done() |