changeset 13600:9729eee8ef6a

8218863: Better endpoint checks Reviewed-by: xuelei, ahgross, jnimeh, mullan, rhalade
author coffeys
date Wed, 20 Mar 2019 10:27:00 +0000
parents 2a567237233b
children e616364fb924
files src/share/classes/sun/security/ssl/SSLContextImpl.java src/share/classes/sun/security/ssl/X509TrustManagerImpl.java
diffstat 2 files changed, 53 insertions(+), 50 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/sun/security/ssl/SSLContextImpl.java	Tue Mar 19 14:16:20 2019 -0700
+++ b/src/share/classes/sun/security/ssl/SSLContextImpl.java	Wed Mar 20 10:27:00 2019 +0000
@@ -1114,8 +1114,9 @@
         checkAdditionalTrust(chain, authType, engine, false);
     }
 
-    private void checkAdditionalTrust(X509Certificate[] chain, String authType,
-                Socket socket, boolean isClient) throws CertificateException {
+    private void checkAdditionalTrust(X509Certificate[] chain,
+            String authType, Socket socket,
+            boolean checkClientTrusted) throws CertificateException {
         if (socket != null && socket.isConnected() &&
                                     socket instanceof SSLSocket) {
 
@@ -1129,9 +1130,8 @@
             String identityAlg = sslSocket.getSSLParameters().
                                         getEndpointIdentificationAlgorithm();
             if (identityAlg != null && identityAlg.length() != 0) {
-                String hostname = session.getPeerHost();
-                X509TrustManagerImpl.checkIdentity(
-                                    hostname, chain[0], identityAlg);
+                X509TrustManagerImpl.checkIdentity(session, chain,
+                                    identityAlg, checkClientTrusted);
             }
 
             // try the best to check the algorithm constraints
@@ -1155,12 +1155,13 @@
                 constraints = new SSLAlgorithmConstraints(sslSocket, true);
             }
 
-            checkAlgorithmConstraints(chain, constraints, isClient);
+            checkAlgorithmConstraints(chain, constraints, checkClientTrusted);
         }
     }
 
-    private void checkAdditionalTrust(X509Certificate[] chain, String authType,
-            SSLEngine engine, boolean isClient) throws CertificateException {
+    private void checkAdditionalTrust(X509Certificate[] chain,
+            String authType, SSLEngine engine,
+            boolean checkClientTrusted) throws CertificateException {
         if (engine != null) {
             SSLSession session = engine.getHandshakeSession();
             if (session == null) {
@@ -1171,9 +1172,8 @@
             String identityAlg = engine.getSSLParameters().
                                         getEndpointIdentificationAlgorithm();
             if (identityAlg != null && identityAlg.length() != 0) {
-                String hostname = session.getPeerHost();
-                X509TrustManagerImpl.checkIdentity(
-                                    hostname, chain[0], identityAlg);
+                X509TrustManagerImpl.checkIdentity(session, chain,
+                                    identityAlg, checkClientTrusted);
             }
 
             // try the best to check the algorithm constraints
@@ -1197,12 +1197,13 @@
                 constraints = new SSLAlgorithmConstraints(engine, true);
             }
 
-            checkAlgorithmConstraints(chain, constraints, isClient);
+            checkAlgorithmConstraints(chain, constraints, checkClientTrusted);
         }
     }
 
     private void checkAlgorithmConstraints(X509Certificate[] chain,
-            AlgorithmConstraints constraints, boolean isClient) throws CertificateException {
+            AlgorithmConstraints constraints,
+            boolean checkClientTrusted) throws CertificateException {
 
         try {
             // Does the certificate chain end with a trusted certificate?
@@ -1222,7 +1223,8 @@
             if (checkedLength >= 0) {
                 AlgorithmChecker checker =
                         new AlgorithmChecker(constraints, null,
-                                (isClient ? Validator.VAR_TLS_CLIENT : Validator.VAR_TLS_SERVER));
+                                (checkClientTrusted ? Validator.VAR_TLS_CLIENT :
+                                            Validator.VAR_TLS_SERVER));
                 checker.init(false);
                 for (int i = checkedLength; i >= 0; i--) {
                     Certificate cert = chain[i];
--- a/src/share/classes/sun/security/ssl/X509TrustManagerImpl.java	Tue Mar 19 14:16:20 2019 -0700
+++ b/src/share/classes/sun/security/ssl/X509TrustManagerImpl.java	Wed Mar 20 10:27:00 2019 +0000
@@ -145,7 +145,7 @@
     }
 
     private Validator checkTrustedInit(X509Certificate[] chain,
-                                        String authType, boolean isClient) {
+            String authType, boolean checkClientTrusted) {
         if (chain == null || chain.length == 0) {
             throw new IllegalArgumentException(
                 "null or zero-length certificate chain");
@@ -157,7 +157,7 @@
         }
 
         Validator v = null;
-        if (isClient) {
+        if (checkClientTrusted) {
             v = clientValidator;
             if (v == null) {
                 synchronized (this) {
@@ -187,9 +187,10 @@
     }
 
 
-    private void checkTrusted(X509Certificate[] chain, String authType,
-                Socket socket, boolean isClient) throws CertificateException {
-        Validator v = checkTrustedInit(chain, authType, isClient);
+    private void checkTrusted(X509Certificate[] chain,
+            String authType, Socket socket,
+            boolean checkClientTrusted) throws CertificateException {
+        Validator v = checkTrustedInit(chain, authType, checkClientTrusted);
 
         AlgorithmConstraints constraints = null;
         if ((socket != null) && socket.isConnected() &&
@@ -205,8 +206,7 @@
             String identityAlg = sslSocket.getSSLParameters().
                                         getEndpointIdentificationAlgorithm();
             if (identityAlg != null && identityAlg.length() != 0) {
-                checkIdentity(session, chain[0], identityAlg, isClient,
-                        getRequestedServerNames(socket));
+                checkIdentity(session, chain, identityAlg, checkClientTrusted);
             }
 
             // create the algorithm constraints
@@ -231,7 +231,7 @@
         }
 
         X509Certificate[] trustedChain = null;
-        if (isClient) {
+        if (checkClientTrusted) {
             trustedChain = validate(v, chain, constraints, null);
         } else {
             trustedChain = validate(v, chain, constraints, authType);
@@ -242,9 +242,10 @@
         }
     }
 
-    private void checkTrusted(X509Certificate[] chain, String authType,
-            SSLEngine engine, boolean isClient) throws CertificateException {
-        Validator v = checkTrustedInit(chain, authType, isClient);
+    private void checkTrusted(X509Certificate[] chain,
+            String authType, SSLEngine engine,
+            boolean checkClientTrusted) throws CertificateException {
+        Validator v = checkTrustedInit(chain, authType, checkClientTrusted);
 
         AlgorithmConstraints constraints = null;
         if (engine != null) {
@@ -257,8 +258,7 @@
             String identityAlg = engine.getSSLParameters().
                                         getEndpointIdentificationAlgorithm();
             if (identityAlg != null && identityAlg.length() != 0) {
-                checkIdentity(session, chain[0], identityAlg, isClient,
-                        getRequestedServerNames(engine));
+                checkIdentity(session, chain, identityAlg, checkClientTrusted);
             }
 
             // create the algorithm constraints
@@ -283,7 +283,7 @@
         }
 
         X509Certificate[] trustedChain = null;
-        if (isClient) {
+        if (checkClientTrusted) {
             trustedChain = validate(v, chain, constraints, null);
         } else {
             trustedChain = validate(v, chain, constraints, authType);
@@ -373,13 +373,8 @@
         if (socket != null && socket.isConnected() &&
                                         socket instanceof SSLSocket) {
 
-            SSLSocket sslSocket = (SSLSocket)socket;
-            SSLSession session = sslSocket.getHandshakeSession();
-
-            if (session != null && (session instanceof ExtendedSSLSession)) {
-                ExtendedSSLSession extSession = (ExtendedSSLSession)session;
-                return extSession.getRequestedServerNames();
-            }
+            return getRequestedServerNames(
+                    ((SSLSocket)socket).getHandshakeSession());
         }
 
         return Collections.<SNIServerName>emptyList();
@@ -388,12 +383,16 @@
     // Also used by X509KeyManagerImpl
     static List<SNIServerName> getRequestedServerNames(SSLEngine engine) {
         if (engine != null) {
-            SSLSession session = engine.getHandshakeSession();
+            return getRequestedServerNames(engine.getHandshakeSession());
+        }
 
-            if (session != null && (session instanceof ExtendedSSLSession)) {
-                ExtendedSSLSession extSession = (ExtendedSSLSession)session;
-                return extSession.getRequestedServerNames();
-            }
+        return Collections.<SNIServerName>emptyList();
+    }
+
+    private static List<SNIServerName> getRequestedServerNames(
+            SSLSession session) {
+        if (session != null && (session instanceof ExtendedSSLSession)) {
+            return ((ExtendedSSLSession)session).getRequestedServerNames();
         }
 
         return Collections.<SNIServerName>emptyList();
@@ -414,22 +413,23 @@
      * the identity checking aginst the server_name extension if present, and
      * may failove to peer host checking.
      */
-    private static void checkIdentity(SSLSession session,
-            X509Certificate cert,
+    static void checkIdentity(SSLSession session,
+            X509Certificate [] trustedChain,
             String algorithm,
-            boolean isClient,
-            List<SNIServerName> sniNames) throws CertificateException {
+            boolean checkClientTrusted) throws CertificateException {
 
         boolean identifiable = false;
         String peerHost = session.getPeerHost();
-        if (isClient) {
-            String hostname = getHostNameInSNI(sniNames);
-            if (hostname != null) {
+        if (!checkClientTrusted) {
+            List<SNIServerName> sniNames = getRequestedServerNames(session);
+            String sniHostName = getHostNameInSNI(sniNames);
+            if (sniHostName != null) {
                 try {
-                    checkIdentity(hostname, cert, algorithm);
+                    checkIdentity(sniHostName,
+                            trustedChain[0], algorithm);
                     identifiable = true;
                 } catch (CertificateException ce) {
-                    if (hostname.equalsIgnoreCase(peerHost)) {
+                    if (sniHostName.equalsIgnoreCase(peerHost)) {
                         throw ce;
                     }
 
@@ -439,7 +439,8 @@
         }
 
         if (!identifiable) {
-            checkIdentity(peerHost, cert, algorithm);
+            checkIdentity(peerHost,
+                    trustedChain[0], algorithm);
         }
     }