summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream/xmlstream.py')
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py77
1 files changed, 65 insertions, 12 deletions
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index daa1af1a..56177556 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -35,7 +35,7 @@ from xml.parsers.expat import ExpatError
import sleekxmpp
from sleekxmpp.thirdparty.statemachine import StateMachine
-from sleekxmpp.xmlstream import Scheduler, tostring
+from sleekxmpp.xmlstream import Scheduler, tostring, cert
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
from sleekxmpp.xmlstream.matcher import MatchXMLMask
@@ -181,6 +181,9 @@ class XMLStream(object):
#: The domain to try when querying DNS records.
self.default_domain = ''
+
+ #: The expected name of the server, for validation.
+ self._expected_server_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
@@ -313,8 +316,9 @@ class XMLStream(object):
self.dns_service = None
self.add_event_handler('connected', self._handle_connected)
- self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('disconnected', self._end_keepalive)
+ self.add_event_handler('session_start', self._start_keepalive)
+ self.add_event_handler('session_start', self._cert_expiration)
def use_signals(self, signals=None):
"""Register signal handlers for ``SIGHUP`` and ``SIGTERM``.
@@ -500,10 +504,17 @@ class XMLStream(object):
self.socket.connect(self.address)
if self.use_ssl and self.ssl_support:
- cert = self.socket.getpeercert(binary_form=True)
- cert = ssl.DER_cert_to_PEM_cert(cert)
- log.debug('CERT: %s', cert)
- self.event('ssl_cert', cert, direct=True)
+ self._der_cert = self.socket.getpeercert(binary_form=True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
+ log.debug('CERT: %s', pem_cert)
+
+ self.event('ssl_cert', pem_cert, direct=True)
+ try:
+ cert.verify(self._expected_server_name, self._der_cert)
+ except cert.CertificateError as err:
+ log.error(err.message)
+ self.event('ssl_invalid_cert', cert, direct=True)
+ self.disconnect(send_close=False)
self.set_socket(self.socket, ignore=True)
#this event is where you should set your application state
@@ -767,12 +778,27 @@ class XMLStream(object):
self.socket.socket = ssl_socket
else:
self.socket = ssl_socket
- self.socket.do_handshake()
- cert = self.socket.getpeercert(binary_form=True)
- cert = ssl.DER_cert_to_PEM_cert(cert)
- log.debug('CERT: %s', cert)
- self.event('ssl_cert', cert, direct=True)
+ try:
+ self.socket.do_handshake()
+ except:
+ log.error('CERT: Invalid certificate trust chain.')
+ self.event('ssl_invalid_chain', direct=True)
+ self.disconnect(self.auto_reconnect, send_close=False)
+ return False
+
+ self._der_cert = self.socket.getpeercert(binary_form=True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
+ log.debug('CERT: %s', pem_cert)
+ self.event('ssl_cert', pem_cert, direct=True)
+
+ try:
+ cert.verify(self._expected_server_name, self._der_cert)
+ except cert.CertificateError as err:
+ log.error(err.message)
+ self.event('ssl_invalid_cert', cert, direct=True)
+ if not self.event_handled('ssl_invalid_cert'):
+ self.disconnect(self.auto_reconnect, send_close=False)
self.set_socket(self.socket)
return True
@@ -780,6 +806,26 @@ class XMLStream(object):
log.warning("Tried to enable TLS, but ssl module not found.")
return False
+ def _cert_expiration(self, event):
+ """Schedule an event for when the TLS certificate expires."""
+
+ def restart():
+ log.warn("The server certificate has expired. Restarting.")
+ self.reconnect()
+
+ cert_ttl = cert.get_ttl(self._der_cert)
+ if cert_ttl is None:
+ return
+
+ if cert_ttl.days < 0:
+ log.warn('CERT: Certificate has expired.')
+ restart()
+
+ log.info('CERT: Time until certificate expiration: %s' % cert_ttl)
+ self.schedule('Certificate Expiration',
+ cert_ttl.seconds,
+ restart)
+
def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive.
@@ -1298,9 +1344,16 @@ class XMLStream(object):
except (Socket.error, ssl.SSLError) as serr:
self.event('socket_error', serr, direct=True)
log.error('Socket Error #%s: %s', serr.errno, serr.strerror)
+ except ValueError as e:
+ msg = e.message if hasattr(e, 'message') else e.args[0]
+
+ if 'I/O operation on closed file' in msg:
+ log.error('Can not read from closed socket.')
+ else:
+ self.exception(e)
except Exception as e:
if not self.stop.is_set():
- log.exception('Connection error.')
+ log.error('Connection error.')
self.exception(e)
if not shutdown and not self.stop.is_set() \