From 48dd01b0bb7db1d93bf2d21e681939bfcd2f1297 Mon Sep 17 00:00:00 2001
From: Lance Stout <lancestout@gmail.com>
Date: Fri, 8 Jun 2012 09:31:44 -0700
Subject: Ensure that all SSL cert error handling is overridable using event
 handlers.

Relevant events:

    ssl_invalid_cert
    ssl_invalid_chain
    ssl_expired_cert
---
 sleekxmpp/xmlstream/xmlstream.py | 37 +++++++++++++++++++++++++++++--------
 1 file changed, 29 insertions(+), 8 deletions(-)

(limited to 'sleekxmpp')

diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index ac0fc256..7376d56d 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -493,7 +493,8 @@ class XMLStream(object):
 
             ssl_socket = ssl.wrap_socket(self.socket,
                                          ca_certs=self.ca_certs,
-                                         cert_reqs=cert_policy)
+                                         cert_reqs=cert_policy,
+                                         do_handshake_on_connect=False)
 
             if hasattr(self.socket, 'socket'):
                 # We are using a testing socket, so preserve the top
@@ -510,6 +511,17 @@ class XMLStream(object):
                 log.debug("Connecting to %s:%s", domain, self.address[1])
                 self.socket.connect(self.address)
 
+                try:
+                    self.socket.do_handshake()
+                except:
+                    log.error('CERT: Invalid certificate trust chain.')
+                    if not self.event_handled('ssl_invalid_chain'):
+                        self.disconnect(self.auto_reconnect, send_close=False)
+                    else:
+                        self.event('ssl_invalid_chain', direct=True)
+                    return False
+
+
                 if self.use_ssl and self.ssl_support:
                     self._der_cert = self.socket.getpeercert(binary_form=True)
                     pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
@@ -520,8 +532,10 @@ class XMLStream(object):
                         cert.verify(self._expected_server_name, self._der_cert)
                     except cert.CertificateError as err:
                         log.error(err.message)
-                        self.event('ssl_invalid_cert', cert, direct=True)
-                        self.disconnect(send_close=False)
+                        if not self.event_handled('ssl_invalid_cert'):
+                            self.disconnect(send_close=False)
+                        else:
+                            self.event('ssl_invalid_cert', cert, direct=True)
 
             self.set_socket(self.socket, ignore=True)
             #this event is where you should set your application state
@@ -790,8 +804,10 @@ class XMLStream(object):
                 self.socket.do_handshake()
             except:
                 log.error('CERT: Invalid certificate trust chain.')
-                self.event('ssl_invalid_chain', direct=True)
-                self.disconnect(self.auto_reconnect, send_close=False)
+                if not self.event_handled('ssl_invalid_chain'):
+                    self.disconnect(self.auto_reconnect, send_close=False)
+                else:
+                    self.event('ssl_invalid_chain', direct=True)
                 return False
 
             self._der_cert = self.socket.getpeercert(binary_form=True)
@@ -803,9 +819,10 @@ class XMLStream(object):
                 cert.verify(self._expected_server_name, self._der_cert)
             except cert.CertificateError as err:
                 log.error(err.message)
-                self.event('ssl_invalid_cert', cert, direct=True)
                 if not self.event_handled('ssl_invalid_cert'):
                     self.disconnect(self.auto_reconnect, send_close=False)
+                else:
+                    self.event('ssl_invalid_cert', cert, direct=True)
 
             self.set_socket(self.socket)
             return True
@@ -820,8 +837,12 @@ class XMLStream(object):
             return
 
         def restart():
-            log.warn("The server certificate has expired. Restarting.")
-            self.reconnect()
+            if not self.event_handled('ssl_expired_cert'):
+                log.warn("The server certificate has expired. Restarting.")
+                self.reconnect()
+            else:
+                pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
+                self.event('ssl_expired_cert', pem_cert)
 
         cert_ttl = cert.get_ttl(self._der_cert)
         if cert_ttl is None:
-- 
cgit v1.2.3