changeset 49135:2b837643a2d4 switch

Lazier SwitchBootstraps, ensure all initializations are covered by test
author redestad
date Fri, 09 Feb 2018 15:58:22 +0100
parents ac07f694a521
children d8cf1d526ccb
files src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java test/jdk/java/lang/runtime/TestSwitchBootstrap.java
diffstat 2 files changed, 133 insertions(+), 37 deletions(-) [+]
line wrap: on
line diff
--- a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java	Fri Feb 09 14:37:00 2018 +0100
+++ b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java	Fri Feb 09 15:58:22 2018 +0100
@@ -32,10 +32,11 @@
 import java.lang.invoke.MethodType;
 import java.util.Arrays;
 import java.util.Comparator;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
@@ -54,7 +55,7 @@
 
     // Shared INIT_HOOK for all switch call sites; looks the target method up in a map
     private static final MethodHandle INIT_HOOK;
-    private static final Map<Class<?>, MethodHandle> switchMethods = new HashMap<>();
+    private static final Map<Class<?>, MethodHandle> switchMethods = new ConcurrentHashMap<>();
 
     // Types that can be handled as int switches
     private static final Set<Class<?>> INT_TYPES
@@ -65,32 +66,39 @@
     private static final Set<Class<?>> LONG_TYPES
             = Set.of(long.class, Long.class);
     private static final Set<Class<?>> DOUBLE_TYPES
-            = Set.of(long.class, Long.class);
+            = Set.of(double.class, Double.class);
 
-    private static final Comparator<String> STRING_BY_HASH
-            = Comparator.comparingInt(Objects::hashCode);
+    private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
+    private static final Function<Class<?>, MethodHandle> lookupSwitchMethod =
+            new Function<>() {
+                @Override
+                public MethodHandle apply(Class<?> c) {
+                    try {
+                        Class<?> switchClass;
+                        if (c == Enum.class)
+                            switchClass = EnumSwitchCallSite.class;
+                        else if (c == String.class)
+                            switchClass = StringSwitchCallSite.class;
+                        else if (INT_TYPES.contains(c) || FLOAT_TYPES.contains(c))
+                            switchClass = IntSwitchCallSite.class;
+                        else if (LONG_TYPES.contains(c) || DOUBLE_TYPES.contains(c))
+                            switchClass = LongSwitchCallSite.class;
+                        else
+                            throw new BootstrapMethodError("Invalid switch type: " + c);
+
+                        return LOOKUP.findVirtual(switchClass, "doSwitch",
+                                                  MethodType.methodType(int.class, c));
+                    }
+                    catch (ReflectiveOperationException e) {
+                        throw new BootstrapMethodError("Invalid switch type: " + c);
+                    }
+                }
+            };
 
     static {
         try {
-            MethodHandles.Lookup lookup = MethodHandles.lookup();
-            INIT_HOOK = lookup.findStatic(SwitchBootstraps.class, "initHook",
+            INIT_HOOK = LOOKUP.findStatic(SwitchBootstraps.class, "initHook",
                                           MethodType.methodType(MethodHandle.class, CallSite.class));
-            for (Class<?> c : INT_TYPES)
-                switchMethods.put(c, lookup.findVirtual(IntSwitchCallSite.class, "doSwitch",
-                                                        MethodType.methodType(int.class, c)));
-            for (Class<?> c : FLOAT_TYPES)
-                switchMethods.put(c, lookup.findVirtual(IntSwitchCallSite.class, "doSwitch",
-                                                        MethodType.methodType(int.class, c)));
-            for (Class<?> c : LONG_TYPES)
-                switchMethods.put(c, lookup.findVirtual(LongSwitchCallSite.class, "doSwitch",
-                                                        MethodType.methodType(int.class, c)));
-            for (Class<?> c : DOUBLE_TYPES)
-                switchMethods.put(c, lookup.findVirtual(LongSwitchCallSite.class, "doSwitch",
-                                                        MethodType.methodType(int.class, c)));
-            switchMethods.put(String.class, lookup.findVirtual(StringSwitchCallSite.class, "doSwitch",
-                                                               MethodType.methodType(int.class, String.class)));
-            switchMethods.put(Enum.class, lookup.findVirtual(EnumSwitchCallSite.class, "doSwitch",
-                                                             MethodType.methodType(int.class, Enum.class)));
         }
         catch (ReflectiveOperationException e) {
             throw new ExceptionInInitializerError(e);
@@ -98,7 +106,7 @@
     }
 
     private static<T extends CallSite> MethodHandle initHook(T receiver) {
-        return switchMethods.get(receiver.type().parameterType(0))
+        return switchMethods.computeIfAbsent(receiver.type().parameterType(0), lookupSwitchMethod)
                             .bindTo(receiver);
     }
 
@@ -406,6 +414,9 @@
     }
 
     static class StringSwitchCallSite extends ConstantCallSite {
+        private static final Comparator<String> STRING_BY_HASH
+                = Comparator.comparingInt(Objects::hashCode);
+
         private final String[] sortedByHash;
         private final int[] indexes;
         private final boolean collisions;
--- a/test/jdk/java/lang/runtime/TestSwitchBootstrap.java	Fri Feb 09 14:37:00 2018 +0100
+++ b/test/jdk/java/lang/runtime/TestSwitchBootstrap.java	Fri Feb 09 15:58:22 2018 +0100
@@ -36,6 +36,7 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
+import jdk.test.lib.RandomFactory;
 
 import org.testng.annotations.Test;
 
@@ -44,32 +45,41 @@
 
 /**
  * @test
+ * @key randomness
+ * @library /test/lib
+ * @build jdk.test.lib.RandomFactory
  * @run testng TestSwitchBootstrap
  */
 @Test
 public class TestSwitchBootstrap {
-    private final static Set<Class<?>> ALL_TYPES = Set.of(int.class, short.class, byte.class, char.class,
-                                                          Integer.class, Short.class, Byte.class, Character.class);
-    private final static Set<Class<?>> NON_BYTE_TYPES = Set.of(int.class, Integer.class, short.class, Short.class);
+    private final static Set<Class<?>> ALL_INT_TYPES = Set.of(int.class, short.class, byte.class, char.class,
+                                                              Integer.class, Short.class, Byte.class, Character.class);
+    private final static Set<Class<?>> SIGNED_NON_BYTE_TYPES = Set.of(int.class, Integer.class, short.class, Short.class);
+    private final static Set<Class<?>> CHAR_TYPES = Set.of(char.class, Character.class);
     private final static Set<Class<?>> BYTE_TYPES = Set.of(byte.class, Byte.class);
-    private final static Set<Class<?>> FLOAT_TYPES = Set.of(float.class, Float.class);
     private final static Set<Class<?>> SIGNED_TYPES
             = Set.of(int.class, short.class, byte.class,
                      Integer.class, Short.class, Byte.class);
 
     public static final MethodHandle BSM_INT_SWITCH;
+    public static final MethodHandle BSM_LONG_SWITCH;
     public static final MethodHandle BSM_FLOAT_SWITCH;
+    public static final MethodHandle BSM_DOUBLE_SWITCH;
     public static final MethodHandle BSM_STRING_SWITCH;
     public static final MethodHandle BSM_ENUM_SWITCH;
 
-    private final static Random random = new Random(System.currentTimeMillis());
+    private final static Random random = RandomFactory.getRandom();
 
     static {
         try {
             BSM_INT_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "intSwitch",
                                                                MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, int[].class));
+            BSM_LONG_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "longSwitch",
+                                                                MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, long[].class));
             BSM_FLOAT_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "floatSwitch",
-                                                               MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, float[].class));
+                                                                 MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, float[].class));
+            BSM_DOUBLE_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "doubleSwitch",
+                                                                  MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, double[].class));
             BSM_STRING_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "stringSwitch",
                                                                   MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, String[].class));
             BSM_ENUM_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "enumSwitch",
@@ -142,6 +152,14 @@
                 assertEquals(labels.length, mhs.get(Short.class).invoke((short) i));
                 assertEquals(labels.length, mhs.get(int.class).invoke(i));
                 assertEquals(labels.length, mhs.get(Integer.class).invoke(i));
+                if (i >= 0) {
+                    assertEquals(labels.length, mhs.get(char.class).invoke((char)i));
+                    assertEquals(labels.length, mhs.get(Character.class).invoke((char)i));
+                }
+                if (i >= -128 && i <= 127) {
+                    assertEquals(labels.length, mhs.get(byte.class).invoke((byte)i));
+                    assertEquals(labels.length, mhs.get(Byte.class).invoke((byte)i));
+                }
             }
         }
 
@@ -176,6 +194,56 @@
         assertEquals(-1, (int) mhs.get(Float.class).invoke(null));
     }
 
+    private void testDouble(double... labels) throws Throwable {
+        Map<Class<?>, MethodHandle> mhs
+                = Map.of(double.class, ((CallSite) BSM_DOUBLE_SWITCH.invoke(MethodHandles.lookup(), "", switchType(double.class), labels)).dynamicInvoker(),
+                         Double.class, ((CallSite) BSM_DOUBLE_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Double.class), labels)).dynamicInvoker());
+
+        var labelList = new ArrayList<Double>();
+        for (double label : labels)
+            labelList.add(label);
+
+        for (int i=0; i<labels.length; i++) {
+            assertEquals(i, (int) mhs.get(double.class).invokeExact((double) labels[i]));
+            assertEquals(i, (int) mhs.get(Double.class).invokeExact((Double) labels[i]));
+        }
+
+        double[] someDoubles = { 1.0, Double.MIN_VALUE, 3.14 };
+        for (double f : someDoubles) {
+            if (!labelList.contains(f)) {
+                assertEquals(labels.length, mhs.get(double.class).invoke((double) f));
+                assertEquals(labels.length, mhs.get(Double.class).invoke((double) f));
+            }
+        }
+
+        assertEquals(-1, (int) mhs.get(Double.class).invoke(null));
+    }
+
+    private void testLong(long... labels) throws Throwable {
+        Map<Class<?>, MethodHandle> mhs
+                = Map.of(long.class, ((CallSite) BSM_LONG_SWITCH.invoke(MethodHandles.lookup(), "", switchType(long.class), labels)).dynamicInvoker(),
+                         Long.class, ((CallSite) BSM_LONG_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Long.class), labels)).dynamicInvoker());
+
+        List<Long> labelList = new ArrayList<>();
+        for (long label : labels)
+            labelList.add(label);
+
+        for (int i=0; i<labels.length; i++) {
+            assertEquals(i, (int) mhs.get(long.class).invokeExact((long) labels[i]));
+            assertEquals(i, (int) mhs.get(Long.class).invokeExact((Long) labels[i]));
+        }
+
+        long[] someLongs = { 1L, Long.MIN_VALUE, Long.MAX_VALUE };
+        for (long l : someLongs) {
+            if (!labelList.contains(l)) {
+                assertEquals(labels.length, mhs.get(long.class).invoke((long) l));
+                assertEquals(labels.length, mhs.get(Long.class).invoke((long) l));
+            }
+        }
+
+        assertEquals(-1, (int) mhs.get(Long.class).invoke(null));
+    }
+
     private void testString(String... targets) throws Throwable {
         MethodHandle indy = ((CallSite) BSM_STRING_SWITCH.invoke(MethodHandles.lookup(), "", switchType(String.class), targets)).dynamicInvoker();
         List<String> targetList = Stream.of(targets)
@@ -220,12 +288,12 @@
     }
 
     public void testInt() throws Throwable {
-        testInt(ALL_TYPES, 8, 6, 7, 5, 3, 0, 9);
-        testInt(ALL_TYPES, 1, 2, 4, 8, 16);
-        testInt(ALL_TYPES, 5, 4, 3, 2, 1, 0);
+        testInt(ALL_INT_TYPES, 8, 6, 7, 5, 3, 0, 9);
+        testInt(ALL_INT_TYPES, 1, 2, 4, 8, 16);
+        testInt(ALL_INT_TYPES, 5, 4, 3, 2, 1, 0);
         testInt(SIGNED_TYPES, 5, 4, 3, 2, 1, 0, -1);
         testInt(SIGNED_TYPES, -1);
-        testInt(ALL_TYPES, new int[] { });
+        testInt(ALL_INT_TYPES, new int[] { });
 
         for (int i=0; i<5; i++) {
             int len = 50 + random.nextInt(800);
@@ -233,7 +301,13 @@
                                  .distinct()
                                  .limit(len)
                                  .toArray();
-            testInt(NON_BYTE_TYPES, arr);
+            testInt(SIGNED_NON_BYTE_TYPES, arr);
+
+            arr = IntStream.generate(() -> random.nextInt(10000))
+                    .distinct()
+                    .limit(len)
+                    .toArray();
+            testInt(CHAR_TYPES, arr);
 
             arr = IntStream.generate(() -> random.nextInt(127) - 64)
                            .distinct()
@@ -244,7 +318,12 @@
     }
 
     public void testLong() throws Throwable {
-        // @@@
+        testLong(1L, Long.MIN_VALUE, Long.MAX_VALUE);
+        testLong(8L, 2L, 5L, 4L, 3L, 9L, 1L);
+        testLong(new long[] { });
+
+        // @@@ Random tests
+        // @@@ More tests for weird values
     }
 
     public void testFloat() throws Throwable {
@@ -257,7 +336,13 @@
     }
 
     public void testDouble() throws Throwable {
-        // @@@
+        testDouble(0.0, -0.0, -1.0, 1.0, 3.14, Double.MIN_VALUE, Double.MAX_VALUE,
+                   Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
+        testDouble(new double[] { });
+        testDouble(0.0f, 1.0f, 3.14f, Double.NaN);
+
+        // @@@ Random tests
+        // @@@ More tests for weird values
     }
 
     public void testString() throws Throwable {