summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/xmlstream/xmlstream.py')
-rw-r--r--slixmpp/xmlstream/xmlstream.py46
1 files changed, 21 insertions, 25 deletions
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index d5dce586..1fa07b0c 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -64,12 +64,12 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
- def __init__(self, socket=None, host='', port=0):
+ def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
self.transport = None
- # The socket the is used internally by the transport object
+ # The socket that is used internally by the transport object
self.socket = None
self.connect_loop_wait = 0
@@ -354,7 +354,10 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.event(self.event_when_connected)
self.transport = transport
- self.socket = self.transport.get_extra_info("socket")
+ self.socket = self.transport.get_extra_info(
+ "ssl_object",
+ default=self.transport.get_extra_info("socket")
+ )
self.init_parser()
self.send_raw(self.stream_header)
self.dns_answers = None
@@ -527,36 +530,29 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
- def start_tls(self):
+ async def start_tls(self):
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
self.event_when_connected = "tls_success"
-
ssl_context = self.get_ssl_context()
- ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context,
- sock=self.socket,
- server_hostname=self.default_domain)
- async def ssl_coro():
- try:
- transp, prot = await ssl_connect_routine
- except ssl.SSLError as e:
- log.debug('SSL: Unable to connect', exc_info=True)
- log.error('CERT: Invalid certificate trust chain.')
- if not self.event_handled('ssl_invalid_chain'):
- self.disconnect()
- else:
- self.event('ssl_invalid_chain', e)
+ try:
+ transp = await self.loop.start_tls(self.transport, self, ssl_context)
+ except ssl.SSLError as e:
+ log.debug('SSL: Unable to connect', exc_info=True)
+ log.error('CERT: Invalid certificate trust chain.')
+ if not self.event_handled('ssl_invalid_chain'):
+ self.disconnect()
else:
- # Workaround for a regression in 3.4 where ssl_object was not set.
- der_cert = transp.get_extra_info("ssl_object",
- default=transp.get_extra_info("socket")).getpeercert(True)
- pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
- self.event('ssl_cert', pem_cert)
-
- asyncio.ensure_future(ssl_coro())
+ self.event('ssl_invalid_chain', e)
+ return False
+ der_cert = transp.get_extra_info("ssl_object").getpeercert(True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
+ self.event('ssl_cert', pem_cert)
+ self.connection_made(transp)
+ return True
def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive.