diff options
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r-- | slixmpp/xmlstream/xmlstream.py | 46 |
1 files changed, 21 insertions, 25 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index d5dce586..1fa07b0c 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -64,12 +64,12 @@ class XMLStream(asyncio.BaseProtocol): :param int port: The port to use for the connection. Defaults to 0. """ - def __init__(self, socket=None, host='', port=0): + def __init__(self, host='', port=0): # The asyncio.Transport object provided by the connection_made() # callback when we are connected self.transport = None - # The socket the is used internally by the transport object + # The socket that is used internally by the transport object self.socket = None self.connect_loop_wait = 0 @@ -354,7 +354,10 @@ class XMLStream(asyncio.BaseProtocol): """ self.event(self.event_when_connected) self.transport = transport - self.socket = self.transport.get_extra_info("socket") + self.socket = self.transport.get_extra_info( + "ssl_object", + default=self.transport.get_extra_info("socket") + ) self.init_parser() self.send_raw(self.stream_header) self.dns_answers = None @@ -527,36 +530,29 @@ class XMLStream(asyncio.BaseProtocol): return self.ssl_context - def start_tls(self): + async def start_tls(self): """Perform handshakes for TLS. If the handshake is successful, the XML stream will need to be restarted. """ self.event_when_connected = "tls_success" - ssl_context = self.get_ssl_context() - ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context, - sock=self.socket, - server_hostname=self.default_domain) - async def ssl_coro(): - try: - transp, prot = await ssl_connect_routine - except ssl.SSLError as e: - log.debug('SSL: Unable to connect', exc_info=True) - log.error('CERT: Invalid certificate trust chain.') - if not self.event_handled('ssl_invalid_chain'): - self.disconnect() - else: - self.event('ssl_invalid_chain', e) + try: + transp = await self.loop.start_tls(self.transport, self, ssl_context) + except ssl.SSLError as e: + log.debug('SSL: Unable to connect', exc_info=True) + log.error('CERT: Invalid certificate trust chain.') + if not self.event_handled('ssl_invalid_chain'): + self.disconnect() else: - # Workaround for a regression in 3.4 where ssl_object was not set. - der_cert = transp.get_extra_info("ssl_object", - default=transp.get_extra_info("socket")).getpeercert(True) - pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) - self.event('ssl_cert', pem_cert) - - asyncio.ensure_future(ssl_coro()) + self.event('ssl_invalid_chain', e) + return False + der_cert = transp.get_extra_info("ssl_object").getpeercert(True) + pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) + self.event('ssl_cert', pem_cert) + self.connection_made(transp) + return True def _start_keepalive(self, event): """Begin sending whitespace periodically to keep the connection alive. |