summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--slixmpp/__init__.py3
-rw-r--r--slixmpp/features/feature_starttls/starttls.py8
-rw-r--r--slixmpp/xmlstream/xmlstream.py46
3 files changed, 25 insertions, 32 deletions
diff --git a/slixmpp/__init__.py b/slixmpp/__init__.py
index 46804bf5..0730cc60 100644
--- a/slixmpp/__init__.py
+++ b/slixmpp/__init__.py
@@ -6,9 +6,6 @@
See the file LICENSE for copying permission.
"""
-import asyncio
-if hasattr(asyncio, 'sslproto'): # no ssl proto: very old asyncio = no need for this
- asyncio.sslproto._is_sslproto_available=lambda: False
import logging
logging.getLogger(__name__).addHandler(logging.NullHandler())
diff --git a/slixmpp/features/feature_starttls/starttls.py b/slixmpp/features/feature_starttls/starttls.py
index d472dad7..7e3af992 100644
--- a/slixmpp/features/feature_starttls/starttls.py
+++ b/slixmpp/features/feature_starttls/starttls.py
@@ -12,7 +12,7 @@ from slixmpp.stanza import StreamFeatures
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins import BasePlugin
from slixmpp.xmlstream.matcher import MatchXPath
-from slixmpp.xmlstream.handler import Callback
+from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.features.feature_starttls import stanza
@@ -28,7 +28,7 @@ class FeatureSTARTTLS(BasePlugin):
def plugin_init(self):
self.xmpp.register_handler(
- Callback('STARTTLS Proceed',
+ CoroutineCallback('STARTTLS Proceed',
MatchXPath(stanza.Proceed.tag_name()),
self._handle_starttls_proceed,
instream=True))
@@ -58,8 +58,8 @@ class FeatureSTARTTLS(BasePlugin):
self.xmpp.send(features['starttls'])
return True
- def _handle_starttls_proceed(self, proceed):
+ async def _handle_starttls_proceed(self, proceed):
"""Restart the XML stream when TLS is accepted."""
log.debug("Starting TLS")
- if self.xmpp.start_tls():
+ if await self.xmpp.start_tls():
self.xmpp.features.add('starttls')
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index d5dce586..1fa07b0c 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -64,12 +64,12 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
- def __init__(self, socket=None, host='', port=0):
+ def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
self.transport = None
- # The socket the is used internally by the transport object
+ # The socket that is used internally by the transport object
self.socket = None
self.connect_loop_wait = 0
@@ -354,7 +354,10 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.event(self.event_when_connected)
self.transport = transport
- self.socket = self.transport.get_extra_info("socket")
+ self.socket = self.transport.get_extra_info(
+ "ssl_object",
+ default=self.transport.get_extra_info("socket")
+ )
self.init_parser()
self.send_raw(self.stream_header)
self.dns_answers = None
@@ -527,36 +530,29 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
- def start_tls(self):
+ async def start_tls(self):
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
self.event_when_connected = "tls_success"
-
ssl_context = self.get_ssl_context()
- ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context,
- sock=self.socket,
- server_hostname=self.default_domain)
- async def ssl_coro():
- try:
- transp, prot = await ssl_connect_routine
- except ssl.SSLError as e:
- log.debug('SSL: Unable to connect', exc_info=True)
- log.error('CERT: Invalid certificate trust chain.')
- if not self.event_handled('ssl_invalid_chain'):
- self.disconnect()
- else:
- self.event('ssl_invalid_chain', e)
+ try:
+ transp = await self.loop.start_tls(self.transport, self, ssl_context)
+ except ssl.SSLError as e:
+ log.debug('SSL: Unable to connect', exc_info=True)
+ log.error('CERT: Invalid certificate trust chain.')
+ if not self.event_handled('ssl_invalid_chain'):
+ self.disconnect()
else:
- # Workaround for a regression in 3.4 where ssl_object was not set.
- der_cert = transp.get_extra_info("ssl_object",
- default=transp.get_extra_info("socket")).getpeercert(True)
- pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
- self.event('ssl_cert', pem_cert)
-
- asyncio.ensure_future(ssl_coro())
+ self.event('ssl_invalid_chain', e)
+ return False
+ der_cert = transp.get_extra_info("ssl_object").getpeercert(True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
+ self.event('ssl_cert', pem_cert)
+ self.connection_made(transp)
+ return True
def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive.