summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream')
-rw-r--r--sleekxmpp/xmlstream/cert.py173
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py77
2 files changed, 238 insertions, 12 deletions
diff --git a/sleekxmpp/xmlstream/cert.py b/sleekxmpp/xmlstream/cert.py
new file mode 100644
index 00000000..0f58d4ed
--- /dev/null
+++ b/sleekxmpp/xmlstream/cert.py
@@ -0,0 +1,173 @@
+import logging
+from datetime import datetime, timedelta
+
+
+try:
+ from pyasn1.codec.der import decoder, encoder
+ from pyasn1.type.univ import Any, ObjectIdentifier, OctetString
+ from pyasn1.type.char import BMPString, IA5String, UTF8String
+ from pyasn1.type.useful import GeneralizedTime
+ from pyasn1_modules.rfc2459 import Certificate, DirectoryString, SubjectAltName, GeneralNames, GeneralName
+ from pyasn1_modules.rfc2459 import id_ce_subjectAltName as SUBJECT_ALT_NAME
+ from pyasn1_modules.rfc2459 import id_at_commonName as COMMON_NAME
+
+
+ XMPP_ADDR = ObjectIdentifier('1.3.6.1.5.5.7.8.5')
+ SRV_NAME = ObjectIdentifier('1.3.6.1.5.5.7.8.7')
+
+ HAVE_PYASN1 = True
+except ImportError:
+ HAVE_PYASN1 = False
+
+
+log = logging.getLogger(__name__)
+
+
+class CertificateError(Exception):
+ pass
+
+
+def decode_str(data):
+ encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8'
+ return bytes(data).decode(encoding)
+
+
+def extract_names(raw_cert):
+ results = {'CN': set(),
+ 'DNS': set(),
+ 'SRV': set(),
+ 'URI': set(),
+ 'XMPPAddr': set()}
+
+ cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0]
+ tbs = cert.getComponentByName('tbsCertificate')
+ subject = tbs.getComponentByName('subject')
+ extensions = tbs.getComponentByName('extensions')
+
+ # Extract the CommonName(s) from the cert.
+ for rdnss in subject:
+ for rdns in rdnss:
+ for name in rdns:
+ oid = name.getComponentByName('type')
+ value = name.getComponentByName('value')
+
+ if oid != COMMON_NAME:
+ continue
+
+ value = decoder.decode(value, asn1Spec=DirectoryString())[0]
+ value = decode_str(value.getComponent())
+ results['CN'].add(value)
+
+ # Extract the Subject Alternate Names (DNS, SRV, URI, XMPPAddr)
+ for extension in extensions:
+ oid = extension.getComponentByName('extnID')
+ if oid != SUBJECT_ALT_NAME:
+ continue
+
+ value = decoder.decode(extension.getComponentByName('extnValue'),
+ asn1Spec=OctetString())[0]
+ sa_names = decoder.decode(value, asn1Spec=SubjectAltName())[0]
+ for name in sa_names:
+ name_type = name.getName()
+ if name_type == 'dNSName':
+ results['DNS'].add(decode_str(name.getComponent()))
+ if name_type == 'uniformResourceIdentifier':
+ value = decode_str(name.getComponent())
+ if value.startswith('xmpp:'):
+ results['URI'].add(value[5:])
+ elif name_type == 'otherName':
+ name = name.getComponent()
+
+ oid = name.getComponentByName('type-id')
+ value = name.getComponentByName('value')
+
+ if oid == XMPP_ADDR:
+ value = decoder.decode(value, asn1Spec=UTF8String())[0]
+ results['XMPPAddr'].add(decode_str(value))
+ elif oid == SRV_NAME:
+ value = decoder.decode(value, asn1Spec=IA5String())[0]
+ results['SRV'].add(decode_str(value))
+
+ return results
+
+
+def extract_dates(raw_cert):
+ if not HAVE_PYASN1:
+ log.warning("Could not find pyasn1 module. " + \
+ "SSL certificate expiration COULD NOT BE VERIFIED.")
+ return None, None
+
+ cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0]
+ tbs = cert.getComponentByName('tbsCertificate')
+ validity = tbs.getComponentByName('validity')
+
+ not_before = validity.getComponentByName('notBefore')
+ not_before = str(not_before.getComponent())
+
+ not_after = validity.getComponentByName('notAfter')
+ not_after = str(not_after.getComponent())
+
+ if isinstance(not_before, GeneralizedTime):
+ not_before = datetime.strptime(not_before, '%Y%m%d%H%M%SZ')
+ else:
+ not_before = datetime.strptime(not_before, '%y%m%d%H%M%SZ')
+
+ if isinstance(not_after, GeneralizedTime):
+ not_after = datetime.strptime(not_after, '%Y%m%d%H%M%SZ')
+ else:
+ not_after = datetime.strptime(not_after, '%y%m%d%H%M%SZ')
+
+ return not_before, not_after
+
+
+def get_ttl(raw_cert):
+ not_before, not_after = extract_dates(raw_cert)
+ if not_after is None:
+ return None
+ return not_after - datetime.utcnow()
+
+
+def verify(expected, raw_cert):
+ if not HAVE_PYASN1:
+ log.warning("Could not find pyasn1 module. " + \
+ "SSL certificate COULD NOT BE VERIFIED.")
+ return
+
+ not_before, not_after = extract_dates(raw_cert)
+ cert_names = extract_names(raw_cert)
+
+ now = datetime.utcnow()
+
+ if not_before > now:
+ raise CertificateError(
+ 'Certificate has not entered its valid date range.')
+
+ if not_after <= now:
+ raise CertificateError(
+ 'Certificate has expired.')
+
+ expected_wild = expected[expected.index('.'):]
+ expected_srv = '_xmpp-client.%s' % expected
+
+ for name in cert_names['XMPPAddr']:
+ if name == expected:
+ return True
+ for name in cert_names['SRV']:
+ if name == expected_srv or name == expected:
+ return True
+ for name in cert_names['DNS']:
+ if name == expected:
+ return True
+ if name.startswith('*'):
+ name_wild = name[name.index('.'):]
+ if expected_wild == name_wild:
+ return True
+ for name in cert_names['URI']:
+ if name == expected:
+ return True
+ for name in cert_names['CN']:
+ if name == expected:
+ return True
+
+ raise CertificateError(
+ 'Could not match certficate against hostname: %s' % expected)
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index daa1af1a..56177556 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -35,7 +35,7 @@ from xml.parsers.expat import ExpatError
import sleekxmpp
from sleekxmpp.thirdparty.statemachine import StateMachine
-from sleekxmpp.xmlstream import Scheduler, tostring
+from sleekxmpp.xmlstream import Scheduler, tostring, cert
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
from sleekxmpp.xmlstream.matcher import MatchXMLMask
@@ -181,6 +181,9 @@ class XMLStream(object):
#: The domain to try when querying DNS records.
self.default_domain = ''
+
+ #: The expected name of the server, for validation.
+ self._expected_server_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
@@ -313,8 +316,9 @@ class XMLStream(object):
self.dns_service = None
self.add_event_handler('connected', self._handle_connected)
- self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('disconnected', self._end_keepalive)
+ self.add_event_handler('session_start', self._start_keepalive)
+ self.add_event_handler('session_start', self._cert_expiration)
def use_signals(self, signals=None):
"""Register signal handlers for ``SIGHUP`` and ``SIGTERM``.
@@ -500,10 +504,17 @@ class XMLStream(object):
self.socket.connect(self.address)
if self.use_ssl and self.ssl_support:
- cert = self.socket.getpeercert(binary_form=True)
- cert = ssl.DER_cert_to_PEM_cert(cert)
- log.debug('CERT: %s', cert)
- self.event('ssl_cert', cert, direct=True)
+ self._der_cert = self.socket.getpeercert(binary_form=True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
+ log.debug('CERT: %s', pem_cert)
+
+ self.event('ssl_cert', pem_cert, direct=True)
+ try:
+ cert.verify(self._expected_server_name, self._der_cert)
+ except cert.CertificateError as err:
+ log.error(err.message)
+ self.event('ssl_invalid_cert', cert, direct=True)
+ self.disconnect(send_close=False)
self.set_socket(self.socket, ignore=True)
#this event is where you should set your application state
@@ -767,12 +778,27 @@ class XMLStream(object):
self.socket.socket = ssl_socket
else:
self.socket = ssl_socket
- self.socket.do_handshake()
- cert = self.socket.getpeercert(binary_form=True)
- cert = ssl.DER_cert_to_PEM_cert(cert)
- log.debug('CERT: %s', cert)
- self.event('ssl_cert', cert, direct=True)
+ try:
+ self.socket.do_handshake()
+ except:
+ log.error('CERT: Invalid certificate trust chain.')
+ self.event('ssl_invalid_chain', direct=True)
+ self.disconnect(self.auto_reconnect, send_close=False)
+ return False
+
+ self._der_cert = self.socket.getpeercert(binary_form=True)
+ pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
+ log.debug('CERT: %s', pem_cert)
+ self.event('ssl_cert', pem_cert, direct=True)
+
+ try:
+ cert.verify(self._expected_server_name, self._der_cert)
+ except cert.CertificateError as err:
+ log.error(err.message)
+ self.event('ssl_invalid_cert', cert, direct=True)
+ if not self.event_handled('ssl_invalid_cert'):
+ self.disconnect(self.auto_reconnect, send_close=False)
self.set_socket(self.socket)
return True
@@ -780,6 +806,26 @@ class XMLStream(object):
log.warning("Tried to enable TLS, but ssl module not found.")
return False
+ def _cert_expiration(self, event):
+ """Schedule an event for when the TLS certificate expires."""
+
+ def restart():
+ log.warn("The server certificate has expired. Restarting.")
+ self.reconnect()
+
+ cert_ttl = cert.get_ttl(self._der_cert)
+ if cert_ttl is None:
+ return
+
+ if cert_ttl.days < 0:
+ log.warn('CERT: Certificate has expired.')
+ restart()
+
+ log.info('CERT: Time until certificate expiration: %s' % cert_ttl)
+ self.schedule('Certificate Expiration',
+ cert_ttl.seconds,
+ restart)
+
def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive.
@@ -1298,9 +1344,16 @@ class XMLStream(object):
except (Socket.error, ssl.SSLError) as serr:
self.event('socket_error', serr, direct=True)
log.error('Socket Error #%s: %s', serr.errno, serr.strerror)
+ except ValueError as e:
+ msg = e.message if hasattr(e, 'message') else e.args[0]
+
+ if 'I/O operation on closed file' in msg:
+ log.error('Can not read from closed socket.')
+ else:
+ self.exception(e)
except Exception as e:
if not self.stop.is_set():
- log.exception('Connection error.')
+ log.error('Connection error.')
self.exception(e)
if not shutdown and not self.stop.is_set() \