changeset 7684:8561d74a9e8f

Better implementations of Map defaults Contributed-by: Peter Levart <peter.levart@gmail.com>
author mduigou
date Mon, 18 Mar 2013 14:10:53 -0700
parents 195249c00945
children 8aed7b1a8fd1
files src/share/classes/java/util/HashMap.java src/share/classes/java/util/Map.java src/share/classes/java/util/concurrent/ConcurrentHashMap.java src/share/classes/java/util/concurrent/ConcurrentMap.java
diffstat 4 files changed, 303 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/HashMap.java	Mon Mar 18 11:44:06 2013 -0700
+++ b/src/share/classes/java/util/HashMap.java	Mon Mar 18 14:10:53 2013 -0700
@@ -27,6 +27,8 @@
 
 import java.io.*;
 import java.util.function.Consumer;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 
 /**
  * Hash table based implementation of the <tt>Map</tt> interface.  This
@@ -352,6 +354,13 @@
         return null == entry ? null : entry.getValue();
     }
 
+    @Override
+    public V getOrDefault(Object key, V defaultValue) {
+        Entry<K,V> entry = getEntry(key);
+
+        return null == entry ? defaultValue : entry.getValue();
+    }
+
     /**
      * Returns <tt>true</tt> if this map contains a mapping for the
      * specified key.
@@ -569,6 +578,235 @@
         return (e == null ? null : e.value);
     }
 
+    // optimized implementations of default methods in Map
+
+    @Override
+    public V putIfAbsent(K key, V value) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> e = (Entry<K,V>)table[i];
+        for(; e != null; e = e.next) {
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                return e.value;
+            }
+        }
+
+        modCount++;
+        addEntry(hash, key, value, i);
+        return null;
+    }
+
+    @Override
+    public boolean remove(Object key, Object value) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> prev = (Entry<K,V>)table[i];
+        Entry<K,V> e = prev;
+
+        while (e != null) {
+            Entry<K,V> next = e.next;
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                modCount++;
+                size--;
+                if (prev == e)
+                    table[i] = next;
+                else
+                    prev.next = next;
+                e.recordRemoval(this);
+                return true;
+            }
+            prev = e;
+            e = next;
+        }
+
+        return false;
+    }
+
+    @Override
+    public boolean replace(K key, V oldValue, V newValue) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> e = (Entry<K,V>)table[i];
+        for(; e != null; e = e.next) {
+            if (e.hash == hash && Objects.equals(e.key, key) && Objects.equals(e.value, oldValue)) {
+                e.value = newValue;
+                e.recordAccess(this);
+                return true;
+            }
+        }
+
+        return false;
+    }
+
+    @Override
+    public V replace(K key, V value) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> e = (Entry<K,V>)table[i];
+        for(; e != null; e = e.next) {
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                V oldValue = e.value;
+                e.value = value;
+                e.recordAccess(this);
+                return oldValue;
+            }
+        }
+
+        return null;
+    }
+
+    @Override
+    public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> e = (Entry<K,V>)table[i];
+        for(; e != null; e = e.next) {
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                V oldValue = e.value;
+                return oldValue == null ? mappingFunction.apply(key) : oldValue;
+            }
+        }
+
+        V newValue = mappingFunction.apply(key);
+        if (newValue != null) {
+            modCount++;
+            addEntry(hash, key, newValue, i);
+        }
+
+        return newValue;
+    }
+
+    @Override
+    public V computeIfPresent(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> prev = (Entry<K,V>)table[i];
+        Entry<K,V> e = prev;
+
+        while (e != null) {
+            Entry<K,V> next = e.next;
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                V oldValue = e.value;
+                if (oldValue == null)
+                    break;
+                V newValue = remappingFunction.apply(key, oldValue);
+                modCount++;
+                if (newValue == null) {
+                    size--;
+                    if (prev == e)
+                        table[i] = next;
+                    else
+                        prev.next = next;
+                    e.recordRemoval(this);
+                }
+                else {
+                    e.value = newValue;
+                    e.recordAccess(this);
+                }
+                return newValue;
+            }
+            prev = e;
+            e = next;
+        }
+
+        return null;
+    }
+
+    @Override
+    public V compute(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> prev = (Entry<K,V>)table[i];
+        Entry<K,V> e = prev;
+
+        while (e != null) {
+            Entry<K,V> next = e.next;
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                V oldValue = e.value;
+                V newValue = remappingFunction.apply(key, oldValue);
+                if (oldValue != null) {
+                    modCount++;
+                    if (newValue == null) {
+                        size--;
+                        if (prev == e)
+                            table[i] = next;
+                        else
+                            prev.next = next;
+                        e.recordRemoval(this);
+                    }
+                    else {
+                        e.value = newValue;
+                        e.recordAccess(this);
+                    }
+                }
+                return newValue;
+            }
+            prev = e;
+            e = next;
+        }
+
+        V newValue = remappingFunction.apply(key, null);
+        if (newValue != null) {
+            modCount++;
+            addEntry(hash, key, newValue, i);
+        }
+
+        return newValue;
+    }
+
+    @Override
+    public V merge(K key, V value, BiFunction<? super V, ? super V, ? extends V> remappingFunction) {
+        int hash = (key == null) ? 0 : hash(key);
+        int i = indexFor(hash, table.length);
+        @SuppressWarnings("unchecked")
+        Entry<K,V> prev = (Entry<K,V>)table[i];
+        Entry<K,V> e = prev;
+
+        while (e != null) {
+            Entry<K,V> next = e.next;
+            if (e.hash == hash && Objects.equals(e.key, key)) {
+                V oldValue = e.value;
+                if (oldValue != null) {
+                    V newValue = remappingFunction.apply(oldValue, value);
+                    modCount++;
+                    if (newValue == null) {
+                        size--;
+                        if (prev == e)
+                            table[i] = next;
+                        else
+                            prev.next = next;
+                        e.recordRemoval(this);
+                    }
+                    else {
+                        e.value = newValue;
+                        e.recordAccess(this);
+                    }
+                    return newValue;
+                }
+                else
+                    return value;
+            }
+            prev = e;
+            e = next;
+        }
+
+        if (value != null) {
+            modCount++;
+            addEntry(hash, key, value, i);
+        }
+
+        return value;
+    }
+
+    // end of optimized implementations of default methods in Map
+
     /**
      * Removes and returns the entry associated with the specified key
      * in the HashMap.  Returns null if the HashMap contains no mapping
--- a/src/share/classes/java/util/Map.java	Mon Mar 18 11:44:06 2013 -0700
+++ b/src/share/classes/java/util/Map.java	Mon Mar 18 14:10:53 2013 -0700
@@ -483,6 +483,28 @@
     // Defaultable methods
 
     /**
+    *  Returns the value to which the specified key is mapped,
+    *  or {@code defaultValue} if this map contains no mapping
+    *  for the key.
+    *
+    * @param key the key whose associated value is to be returned
+    * @return the value to which the specified key is mapped, or
+    * {@code defaultValue} if this map contains no mapping for the key
+    * @throws ClassCastException if the key is of an inappropriate type for
+    * this map
+    * (<a href="Collection.html#optional-restrictions">optional</a>)
+    * @throws NullPointerException if the specified key is null and this map
+    * does not permit null keys
+    * (<a href="Collection.html#optional-restrictions">optional</a>)
+    */
+    default V getOrDefault(Object key, V defaultValue) {
+        V v;
+        return (null != (v = get(key)))
+            ? v
+            : containsKey(key) ? null : defaultValue;
+    }
+
+    /**
      * Performs the given action on each entry in this map, in the
      * order entries are returned by an entry set iterator, until all entries
      * have been processed or the action throws an {@code Exception}.
@@ -619,7 +641,7 @@
      * The default implementation is equivalent to, for this {@code map}:
      *
      * <pre> {@code
-     * if (map.containsKey(key) && map.get(key).equals(value)) {
+     * if (map.containsKey(key) && Objects.equals(map.get(key), value)) {
      *   map.remove(key);
      *   return true;
      * } else
@@ -639,7 +661,7 @@
      * @since 1.8
      */
     default boolean remove(Object key, Object value) {
-        if (!containsKey(key) || !get(key).equals(value))
+        if (!containsKey(key) || !Objects.equals(get(key), value))
             return false;
         remove(key);
         return true;
@@ -660,7 +682,7 @@
      * The default implementation is equivalent to, for this {@code map}:
      *
      * <pre> {@code
-     * if (map.containsKey(key) && map.get(key).equals(oldValue)) {
+     * if (map.containsKey(key) && Objects.equals(map.get(key), value)) {
      *   map.put(key, newValue);
      *   return true;
      * } else
@@ -681,7 +703,7 @@
      * @since 1.8
      */
     default boolean replace(K key, V oldValue, V newValue) {
-        if (!containsKey(key) || !get(key).equals(oldValue))
+        if (!containsKey(key) || !Objects.equals(get(key), oldValue))
             return false;
         put(key, newValue);
         return true;
@@ -908,15 +930,27 @@
      */
     default V compute(K key,
                       BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
+        V oldValue = get(key);
         for (;;) {
-            V oldValue = get(key);
             V newValue = remappingFunction.apply(key, oldValue);
-            if (newValue != null) {
-                if (replace(key, oldValue, newValue))
-                    return newValue;
+            if (oldValue != null) {
+                if (newValue != null) {
+                    if (replace(key, oldValue, newValue))
+                        return newValue;
+                }
+                else if (remove(key, oldValue)) {
+                    return null;
+                }
+                oldValue = get(key);
             }
-            else if (remove(key, oldValue))
-                return null;
+            else {
+                if (newValue != null) {
+                    if ((oldValue = putIfAbsent(key, newValue)) == null)
+                        return newValue;
+                }
+                else
+                    return null;
+            }
         }
     }
 
@@ -978,18 +1012,27 @@
      */
     default V merge(K key, V value,
                     BiFunction<? super V, ? super V, ? extends V> remappingFunction) {
+        V oldValue = get(key);
         for (;;) {
-            V oldValue, newValue;
-            if ((oldValue = get(key)) == null) {
-                if (value == null || putIfAbsent(key, value) == null)
-                    return value;
+            if (oldValue != null) {
+                V newValue = remappingFunction.apply(oldValue, value);
+                if (newValue != null) {
+                    if (replace(key, oldValue, newValue))
+                        return newValue;
+                }
+                else if (remove(key, oldValue)) {
+                    return null;
+                }
+                oldValue = get(key);
             }
-            else if ((newValue = remappingFunction.apply(oldValue, value)) != null) {
-                if (replace(key, oldValue, newValue))
-                    return newValue;
+            else {
+                if (value != null) {
+                    if ((oldValue = putIfAbsent(key, value)) == null)
+                        return value;
+                }
+                else
+                    return null;
             }
-            else if (remove(key, oldValue))
-                return null;
         }
     }
 }
--- a/src/share/classes/java/util/concurrent/ConcurrentHashMap.java	Mon Mar 18 11:44:06 2013 -0700
+++ b/src/share/classes/java/util/concurrent/ConcurrentHashMap.java	Mon Mar 18 14:10:53 2013 -0700
@@ -2670,7 +2670,7 @@
      * @return the mapping for the key, if present; else the defaultValue
      * @throws NullPointerException if the specified key is null
      */
-    public V getValueOrDefault(Object key, V defaultValue) {
+    public V getOrDefault(Object key, V defaultValue) {
         V v;
         return (v = internalGet(key)) == null ? defaultValue : v;
     }
--- a/src/share/classes/java/util/concurrent/ConcurrentMap.java	Mon Mar 18 11:44:06 2013 -0700
+++ b/src/share/classes/java/util/concurrent/ConcurrentMap.java	Mon Mar 18 14:10:53 2013 -0700
@@ -91,7 +91,7 @@
      * Removes the entry for a key only if currently mapped to a given value.
      * This is equivalent to
      *  <pre> {@code
-     * if (map.containsKey(key) && map.get(key).equals(value)) {
+     * if (map.containsKey(key) && Objects.equals(map.get(key), value)) {
      *   map.remove(key);
      *   return true;
      * } else
@@ -117,7 +117,7 @@
      * Replaces the entry for a key only if currently mapped to a given value.
      * This is equivalent to
      *  <pre> {@code
-     * if (map.containsKey(key) && map.get(key).equals(oldValue)) {
+     * if (map.containsKey(key) && Objects.equals(map.get(key), oldValue)) {
      *   map.put(key, newValue);
      *   return true;
      * } else