From eb45e95a3073e4bfc1b7ede4faf98544ce11421a Mon Sep 17 00:00:00 2001
From: Yann Leboulanger <asterix@lagaule.org>
Date: Mon, 2 Jan 2012 16:39:06 +0100
Subject: [PATCH] check hostname in SSL certificates. Fixes #7066

---
 src/common/check_X509.py  | 164 ++++++++++++++++++++++++++++++++++++++
 src/common/connection.py  |  26 ++++--
 src/common/xmpp/tls_nb.py |   2 -
 3 files changed, 184 insertions(+), 8 deletions(-)
 create mode 100644 src/common/check_X509.py

diff --git a/src/common/check_X509.py b/src/common/check_X509.py
new file mode 100644
index 0000000000..ca896eee15
--- /dev/null
+++ b/src/common/check_X509.py
@@ -0,0 +1,164 @@
+from pyasn1.type import univ, constraint, char, namedtype, tag
+from pyasn1.codec.der.decoder import decode
+from common.helpers import prep, InvalidFormat
+
+MAX = 64
+oid_xmppaddr = '(1, 3, 6, 1, 5, 5, 7, 8, 5)'
+oid_dnssrv   = '(1, 3, 6, 1, 5, 5, 7, 8, 7)'
+
+
+
+class DirectoryString(univ.Choice):
+    componentType = namedtype.NamedTypes(
+        namedtype.NamedType(
+            'teletexString', char.TeletexString().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        namedtype.NamedType(
+            'printableString', char.PrintableString().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        namedtype.NamedType(
+            'universalString', char.UniversalString().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        namedtype.NamedType(
+            'utf8String', char.UTF8String().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        namedtype.NamedType(
+            'bmpString', char.BMPString().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        namedtype.NamedType(
+            'ia5String', char.IA5String().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        namedtype.NamedType(
+            'gString', univ.OctetString().subtype(
+                subtypeSpec=constraint.ValueSizeConstraint(1, MAX))),
+        )
+
+class AttributeValue(DirectoryString):
+    pass
+
+class AttributeType(univ.ObjectIdentifier):
+    pass
+
+class AttributeTypeAndValue(univ.Sequence):
+    componentType = namedtype.NamedTypes(
+        namedtype.NamedType('type', AttributeType()),
+        namedtype.NamedType('value', AttributeValue()),
+        )
+
+class RelativeDistinguishedName(univ.SetOf):
+    componentType = AttributeTypeAndValue()
+
+class RDNSequence(univ.SequenceOf):
+    componentType = RelativeDistinguishedName()
+
+class Name(univ.Choice):
+    componentType = namedtype.NamedTypes(
+        namedtype.NamedType('', RDNSequence()),
+        )
+
+class GeneralName(univ.Choice):
+    componentType = namedtype.NamedTypes(
+        namedtype.NamedType('otherName', univ.Sequence().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatConstructed, 0x0))),
+        namedtype.NamedType('rfc822Name', char.IA5String().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatSimple, 1))),
+        namedtype.NamedType('dNSName', char.IA5String().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatSimple, 2))),
+        namedtype.NamedType('x400Address', univ.Sequence().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatConstructed, 0x3))),
+        namedtype.NamedType('directoryName', Name().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatConstructed, 0x4))),
+        namedtype.NamedType('ediPartyName', univ.Sequence().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatConstructed, 0x5))),
+        namedtype.NamedType('uniformResourceIdentifier',
+            char.IA5String().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatSimple, 6))),
+        namedtype.NamedType('iPAddress', univ.OctetString().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatSimple, 7))),
+        namedtype.NamedType('registeredID', univ.ObjectIdentifier().subtype(
+            implicitTag=tag.Tag(tag.tagClassContext,
+            tag.tagFormatSimple, 8))),
+        )
+
+class GeneralNames(univ.SequenceOf):
+    componentType = GeneralName()
+    sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, MAX)
+
+
+#s = '0\x1a\x82\rwww.gajim.org\x82\tgajim.org'
+s = '0\x81\x86\x82\x0c*.jabber.org\x82\njabber.org\xa0\x1a\x06\x08+\x06\x01\x05\x05\x07\x08\x05\xa0\x0e\x0c\x0c*.jabber.org\xa0\x1a\x06\x08+\x06\x01\x05\x05\x07\x08\x07\xa0\x0e\x16\x0c*.jabber.org\xa0\x18\x06\x08+\x06\x01\x05\x05\x07\x08\x05\xa0\x0c\x0c\njabber.org\xa0\x18\x06\x08+\x06\x01\x05\x05\x07\x08\x07\xa0\x0c\x16\njabber.org'
+
+def _parse_asn1(asn1):
+    obj = decode(asn1, asn1Spec=GeneralNames())[0]
+    r = {}
+    for o in obj:
+        name = o.getName()
+        if name == 'dNSName':
+            if name not in r:
+                r[name] = []
+            r[name].append(str(o.getComponent()))
+        if name == 'otherName':
+            if name not in r:
+                r[name] = {}
+            tag = str(tuple(o.getComponent())[0])
+            val = str(tuple(o.getComponent())[1])
+            if tag not in r[name]:
+                r[name][tag] = []
+            r[name][tag].append(val)
+        if name == 'uniformResourceIdentifier':
+            r['uniformResourceIdentifier'] = True
+    return r
+
+def check_certificate(cert, domain):
+    cnt = cert.get_extension_count()
+    if '.' in domain:
+        compared_domain = domain.split('.', 1)[1]
+    else:
+        compared_domain = ''
+    srv_domain = '_xmpp-client.' + domain
+    compared_srv_domain = '_xmpp-client.' + compared_domain
+    for i in range(0, cnt):
+        ext = cert.get_extension(i)
+        if ext.get_short_name() == 'subjectAltName':
+            r = _parse_asn1(ext.get_data())
+            if 'otherName' in r:
+                if oid_xmppaddr in r['otherName']:
+                    for host in r['otherName'][oid_xmppaddr]:
+                        try:
+                            host = prep(None, host, None)
+                        except InvalidFormat:
+                            continue
+                        if host == domain:
+                            return True
+                if oid_dnssrv in r['otherName']:
+                    for host in r['otherName'][oid_dnssrv]:
+                        if host.startswith('_xmpp-client.*.'):
+                            if host.replace('*.', '', 1) == compared_srv_domain:
+                                return True
+                            continue
+                        if host == srv_domain:
+                            return True
+            if 'dNSName' in r:
+                for host in r['dNSName']:
+                    if host.startswith('*.'):
+                        if host[2:] == compared_domain:
+                            return True
+                        continue
+                    if host == domain:
+                        return True
+            if r:
+                return False
+            break
+
+    subject = cert.get_subject()
+    if subject.commonName == domain:
+        return True
+    return False
diff --git a/src/common/connection.py b/src/common/connection.py
index d4f49eca7f..fe4c9e6e66 100644
--- a/src/common/connection.py
+++ b/src/common/connection.py
@@ -57,6 +57,7 @@ from common import gajim
 from common import gpg
 from common import passwords
 from common import exceptions
+from common import check_X509
 from connection_handlers import *
 
 from xmpp import Smacks
@@ -97,6 +98,7 @@ ssl_error = {
 31: _("Authority and issuer serial number mismatch"),
 32: _("Key usage does not include certificate signing"),
 50: _("Application verification failure")
+#100 is for internal usage: host not correct
 }
 
 class CommonConnection:
@@ -1287,9 +1289,9 @@ class Connection(CommonConnection, ConnectionHandlers):
         except AttributeError:
             errnum = -1 # we don't have an errnum
         if errnum > 0 and str(errnum) not in gajim.config.get_per('accounts',
-        self.name, 'ignore_ssl_errors'):
-            text = _('The authenticity of the %s certificate could be invalid.') %\
-                    hostname
+        self.name, 'ignore_ssl_errors').split():
+            text = _('The authenticity of the %s certificate could be invalid.'
+                ) % hostname
             if errnum in ssl_error:
                 text += _('\nSSL Error: <b>%s</b>') % ssl_error[errnum]
             else:
@@ -1301,7 +1303,8 @@ class Connection(CommonConnection, ConnectionHandlers):
                 certificate=con.Connection.ssl_certificate))
             return True
         if hasattr(con.Connection, 'ssl_fingerprint_sha1'):
-            saved_fingerprint = gajim.config.get_per('accounts', self.name, 'ssl_fingerprint_sha1')
+            saved_fingerprint = gajim.config.get_per('accounts', self.name,
+                'ssl_fingerprint_sha1')
             if saved_fingerprint:
                 # Check sha1 fingerprint
                 if con.Connection.ssl_fingerprint_sha1 != saved_fingerprint:
@@ -1310,8 +1313,19 @@ class Connection(CommonConnection, ConnectionHandlers):
                         new_fingerprint=con.Connection.ssl_fingerprint_sha1))
                     return True
             else:
-                gajim.config.set_per('accounts', self.name, 'ssl_fingerprint_sha1',
-                        con.Connection.ssl_fingerprint_sha1)
+                gajim.config.set_per('accounts', self.name,
+                    'ssl_fingerprint_sha1', con.Connection.ssl_fingerprint_sha1)
+        if not check_X509.check_certificate(con.Connection.ssl_certificate,
+        hostname) and '100' not in gajim.config.get_per('accounts', self.name,
+        'ignore_ssl_errors').split():
+            txt = _('The authenticity of the %s certificate could be invalid.'
+                '\nThe certificate does not cover this domain.') % hostname
+            gajim.nec.push_incoming_event(SSLErrorEvent(None, conn=self,
+                error_text=txt, error_num=100, cert=con.Connection.ssl_cert_pem,
+                fingerprint=con.Connection.ssl_fingerprint_sha1,
+                certificate=con.Connection.ssl_certificate))
+            return True
+
         self._register_handlers(con, con_type)
         con.auth(
                 user=name,
diff --git a/src/common/xmpp/tls_nb.py b/src/common/xmpp/tls_nb.py
index cc8393d529..575960fca1 100644
--- a/src/common/xmpp/tls_nb.py
+++ b/src/common/xmpp/tls_nb.py
@@ -451,8 +451,6 @@ class NonBlockingTLS(PlugIn):
         try:
             self._owner.ssl_fingerprint_sha1 = cert.digest('sha1')
             self._owner.ssl_certificate = cert
-            if errnum == 0:
-                return True
             self._owner.ssl_errnum = errnum
             self._owner.ssl_cert_pem = OpenSSL.crypto.dump_certificate(
                     OpenSSL.crypto.FILETYPE_PEM, cert)
-- 
GitLab