diff options
Diffstat (limited to 'sleekxmpp/xmlstream/xmlstream.py')
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 77 |
1 files changed, 65 insertions, 12 deletions
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index daa1af1a..56177556 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -35,7 +35,7 @@ from xml.parsers.expat import ExpatError import sleekxmpp from sleekxmpp.thirdparty.statemachine import StateMachine -from sleekxmpp.xmlstream import Scheduler, tostring +from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from sleekxmpp.xmlstream.handler import Waiter, XMLCallback from sleekxmpp.xmlstream.matcher import MatchXMLMask @@ -181,6 +181,9 @@ class XMLStream(object): #: The domain to try when querying DNS records. self.default_domain = '' + + #: The expected name of the server, for validation. + self._expected_server_name = '' #: The desired, or actual, address of the connected server. self.address = (host, int(port)) @@ -313,8 +316,9 @@ class XMLStream(object): self.dns_service = None self.add_event_handler('connected', self._handle_connected) - self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('disconnected', self._end_keepalive) + self.add_event_handler('session_start', self._start_keepalive) + self.add_event_handler('session_start', self._cert_expiration) def use_signals(self, signals=None): """Register signal handlers for ``SIGHUP`` and ``SIGTERM``. @@ -500,10 +504,17 @@ class XMLStream(object): self.socket.connect(self.address) if self.use_ssl and self.ssl_support: - cert = self.socket.getpeercert(binary_form=True) - cert = ssl.DER_cert_to_PEM_cert(cert) - log.debug('CERT: %s', cert) - self.event('ssl_cert', cert, direct=True) + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + + self.event('ssl_cert', pem_cert, direct=True) + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + log.error(err.message) + self.event('ssl_invalid_cert', cert, direct=True) + self.disconnect(send_close=False) self.set_socket(self.socket, ignore=True) #this event is where you should set your application state @@ -767,12 +778,27 @@ class XMLStream(object): self.socket.socket = ssl_socket else: self.socket = ssl_socket - self.socket.do_handshake() - cert = self.socket.getpeercert(binary_form=True) - cert = ssl.DER_cert_to_PEM_cert(cert) - log.debug('CERT: %s', cert) - self.event('ssl_cert', cert, direct=True) + try: + self.socket.do_handshake() + except: + log.error('CERT: Invalid certificate trust chain.') + self.event('ssl_invalid_chain', direct=True) + self.disconnect(self.auto_reconnect, send_close=False) + return False + + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + self.event('ssl_cert', pem_cert, direct=True) + + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + log.error(err.message) + self.event('ssl_invalid_cert', cert, direct=True) + if not self.event_handled('ssl_invalid_cert'): + self.disconnect(self.auto_reconnect, send_close=False) self.set_socket(self.socket) return True @@ -780,6 +806,26 @@ class XMLStream(object): log.warning("Tried to enable TLS, but ssl module not found.") return False + def _cert_expiration(self, event): + """Schedule an event for when the TLS certificate expires.""" + + def restart(): + log.warn("The server certificate has expired. Restarting.") + self.reconnect() + + cert_ttl = cert.get_ttl(self._der_cert) + if cert_ttl is None: + return + + if cert_ttl.days < 0: + log.warn('CERT: Certificate has expired.') + restart() + + log.info('CERT: Time until certificate expiration: %s' % cert_ttl) + self.schedule('Certificate Expiration', + cert_ttl.seconds, + restart) + def _start_keepalive(self, event): """Begin sending whitespace periodically to keep the connection alive. @@ -1298,9 +1344,16 @@ class XMLStream(object): except (Socket.error, ssl.SSLError) as serr: self.event('socket_error', serr, direct=True) log.error('Socket Error #%s: %s', serr.errno, serr.strerror) + except ValueError as e: + msg = e.message if hasattr(e, 'message') else e.args[0] + + if 'I/O operation on closed file' in msg: + log.error('Can not read from closed socket.') + else: + self.exception(e) except Exception as e: if not self.stop.is_set(): - log.exception('Connection error.') + log.error('Connection error.') self.exception(e) if not shutdown and not self.stop.is_set() \ |