changeset 9286:9cd1f8626c46

Ensure spliterator traversal methods throw NPE for null consumer.
author psandoz
date Thu, 01 Aug 2013 15:28:57 +0100
parents 67384bcf46af
children beff4aea9526
files src/share/classes/java/util/stream/SpinedBuffer.java src/share/classes/java/util/stream/StreamSpliterators.java src/share/classes/java/util/stream/Streams.java test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java test/java/util/stream/bootlib/java/util/stream/SpliteratorTestHelper.java
diffstat 5 files changed, 132 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/stream/SpinedBuffer.java	Thu Aug 01 12:45:43 2013 +0100
+++ b/src/share/classes/java/util/stream/SpinedBuffer.java	Thu Aug 01 15:28:57 2013 +0100
@@ -28,6 +28,7 @@
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Objects;
 import java.util.PrimitiveIterator;
 import java.util.Spliterator;
 import java.util.Spliterators;
@@ -317,6 +318,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super E> consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     consumer.accept(splChunk[splElementIndex++]);
@@ -334,6 +337,8 @@
 
             @Override
             public void forEachRemaining(Consumer<? super E> consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     int i = splElementIndex;
@@ -634,6 +639,8 @@
 
             @Override
             public boolean tryAdvance(T_CONS consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     arrayForOne(splChunk, splElementIndex++, consumer);
@@ -651,6 +658,8 @@
 
             @Override
             public void forEachRemaining(T_CONS consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     int i = splElementIndex;
--- a/src/share/classes/java/util/stream/StreamSpliterators.java	Thu Aug 01 12:45:43 2013 +0100
+++ b/src/share/classes/java/util/stream/StreamSpliterators.java	Thu Aug 01 15:28:57 2013 +0100
@@ -25,6 +25,7 @@
 package java.util.stream;
 
 import java.util.Comparator;
+import java.util.Objects;
 import java.util.Spliterator;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BooleanSupplier;
@@ -294,6 +295,7 @@
 
         @Override
         public boolean tryAdvance(Consumer<? super P_OUT> consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -303,6 +305,7 @@
         @Override
         public void forEachRemaining(Consumer<? super P_OUT> consumer) {
             if (buffer == null && !finished) {
+                Objects.requireNonNull(consumer);
                 init();
 
                 ph.wrapAndCopyInto((Sink<P_OUT>) consumer::accept, spliterator);
@@ -350,6 +353,7 @@
 
         @Override
         public boolean tryAdvance(IntConsumer consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -359,6 +363,7 @@
         @Override
         public void forEachRemaining(IntConsumer consumer) {
             if (buffer == null && !finished) {
+                Objects.requireNonNull(consumer);
                 init();
 
                 ph.wrapAndCopyInto((Sink.OfInt) consumer::accept, spliterator);
@@ -406,6 +411,7 @@
 
         @Override
         public boolean tryAdvance(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -414,6 +420,7 @@
 
         @Override
         public void forEachRemaining(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
             if (buffer == null && !finished) {
                 init();
 
@@ -462,6 +469,7 @@
 
         @Override
         public boolean tryAdvance(DoubleConsumer consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -470,6 +478,7 @@
 
         @Override
         public void forEachRemaining(DoubleConsumer consumer) {
+            Objects.requireNonNull(consumer);
             if (buffer == null && !finished) {
                 init();
 
@@ -696,6 +705,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return false;
 
@@ -713,6 +724,8 @@
 
             @Override
             public void forEachRemaining(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return;
 
@@ -754,6 +767,8 @@
 
             @Override
             public boolean tryAdvance(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return false;
 
@@ -771,6 +786,8 @@
 
             @Override
             public void forEachRemaining(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return;
 
@@ -985,6 +1002,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 while (permitStatus() != PermitStatus.NO_MORE) {
                     if (!s.tryAdvance(this))
                         return false;
@@ -999,6 +1018,8 @@
 
             @Override
             public void forEachRemaining(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 ArrayBuffer.OfRef<T> sb = null;
                 PermitStatus permitStatus;
                 while ((permitStatus = permitStatus()) != PermitStatus.NO_MORE) {
@@ -1051,6 +1072,8 @@
 
             @Override
             public boolean tryAdvance(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 while (permitStatus() != PermitStatus.NO_MORE) {
                     if (!s.tryAdvance((T_CONS) this))
                         return false;
@@ -1066,6 +1089,8 @@
 
             @Override
             public void forEachRemaining(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 T_BUFF sb = null;
                 PermitStatus permitStatus;
                 while ((permitStatus = permitStatus()) != PermitStatus.NO_MORE) {
@@ -1237,6 +1262,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.get());
                 return true;
             }
@@ -1260,6 +1287,8 @@
 
             @Override
             public boolean tryAdvance(IntConsumer action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.getAsInt());
                 return true;
             }
@@ -1283,6 +1312,8 @@
 
             @Override
             public boolean tryAdvance(LongConsumer action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.getAsLong());
                 return true;
             }
@@ -1306,6 +1337,8 @@
 
             @Override
             public boolean tryAdvance(DoubleConsumer action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.getAsDouble());
                 return true;
             }
--- a/src/share/classes/java/util/stream/Streams.java	Thu Aug 01 12:45:43 2013 +0100
+++ b/src/share/classes/java/util/stream/Streams.java	Thu Aug 01 15:28:57 2013 +0100
@@ -26,6 +26,7 @@
 
 import java.util.Comparator;
 import java.util.MayHoldCloseableResource;
+import java.util.Objects;
 import java.util.Spliterator;
 import java.util.function.Consumer;
 import java.util.function.DoubleConsumer;
@@ -81,6 +82,8 @@
 
         @Override
         public boolean tryAdvance(IntConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             final int i = from;
             if (i < upTo) {
                 from++;
@@ -97,6 +100,8 @@
 
         @Override
         public void forEachRemaining(IntConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             int i = from;
             final int hUpTo = upTo;
             int hLast = last;
@@ -200,6 +205,8 @@
 
         @Override
         public boolean tryAdvance(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             final long i = from;
             if (i < upTo) {
                 from++;
@@ -216,6 +223,8 @@
 
         @Override
         public void forEachRemaining(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             long i = from;
             final long hUpTo = upTo;
             int hLast = last;
@@ -389,6 +398,8 @@
 
         @Override
         public boolean tryAdvance(Consumer<? super T> action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -401,6 +412,8 @@
 
         @Override
         public void forEachRemaining(Consumer<? super T> action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -476,6 +489,8 @@
 
         @Override
         public boolean tryAdvance(IntConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -488,6 +503,8 @@
 
         @Override
         public void forEachRemaining(IntConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -563,6 +580,8 @@
 
         @Override
         public boolean tryAdvance(LongConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -575,6 +594,8 @@
 
         @Override
         public void forEachRemaining(LongConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -650,6 +671,8 @@
 
         @Override
         public boolean tryAdvance(DoubleConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -662,6 +685,8 @@
 
         @Override
         public void forEachRemaining(DoubleConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
--- a/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java	Thu Aug 01 12:45:43 2013 +0100
+++ b/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java	Thu Aug 01 15:28:57 2013 +0100
@@ -386,11 +386,23 @@
 
             db.addCollection(CopyOnWriteArraySet::new);
 
-            if (size == 1) {
+            if (size == 0) {
+                db.addCollection(c -> Collections.<Integer>emptySet());
+                db.addList(c -> Collections.<Integer>emptyList());
+            }
+            else if (size == 1) {
                 db.addCollection(c -> Collections.singleton(exp.get(0)));
                 db.addCollection(c -> Collections.singletonList(exp.get(0)));
             }
 
+            {
+                Integer[] ai = new Integer[size];
+                Arrays.fill(ai, 1);
+                db.add(String.format("Collections.nCopies(%d, 1)", exp.size()),
+                       Arrays.asList(ai),
+                       () -> Collections.nCopies(exp.size(), 1).spliterator());
+            }
+
             // Collections.synchronized/unmodifiable/checked wrappers
             db.addCollection(Collections::unmodifiableCollection);
             db.addCollection(c -> Collections.unmodifiableSet(new HashSet<>(c)));
@@ -454,6 +466,13 @@
             db.addMap(ConcurrentHashMap::new);
 
             db.addMap(ConcurrentSkipListMap::new);
+
+            if (size == 0) {
+                db.addMap(m -> Collections.<Integer, Integer>emptyMap());
+            }
+            else if (size == 1) {
+                db.addMap(m -> Collections.singletonMap(exp.get(0), exp.get(0)));
+            }
         }
 
         return spliteratorDataProvider = data.toArray(new Object[0][]);
--- a/test/java/util/stream/bootlib/java/util/stream/SpliteratorTestHelper.java	Thu Aug 01 12:45:43 2013 +0100
+++ b/test/java/util/stream/bootlib/java/util/stream/SpliteratorTestHelper.java	Thu Aug 01 15:28:57 2013 +0100
@@ -22,6 +22,8 @@
  */
 package java.util.stream;
 
+import org.testng.annotations.Test;
+
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -154,6 +156,7 @@
 
         Collection<T> exp = Collections.unmodifiableList(fromForEach);
 
+        testNullPointerException(supplier);
         testForEach(exp, supplier, boxingAdapter, asserter);
         testTryAdvance(exp, supplier, boxingAdapter, asserter);
         testMixedTryAdvanceForEach(exp, supplier, boxingAdapter, asserter);
@@ -166,6 +169,31 @@
 
     //
 
+    private static <T, S extends Spliterator<T>> void testNullPointerException(Supplier<S> s) {
+        S sp = s.get();
+        // Have to check instances and use casts to avoid tripwire messages and
+        // directly test the primitive methods
+        if (sp instanceof Spliterator.OfInt) {
+            Spliterator.OfInt psp = (Spliterator.OfInt) sp;
+            executeAndCatch(NullPointerException.class, () -> psp.forEachRemaining((IntConsumer) null));
+            executeAndCatch(NullPointerException.class, () -> psp.tryAdvance((IntConsumer) null));
+        }
+        else if (sp instanceof Spliterator.OfLong) {
+            Spliterator.OfLong psp = (Spliterator.OfLong) sp;
+            executeAndCatch(NullPointerException.class, () -> psp.forEachRemaining((LongConsumer) null));
+            executeAndCatch(NullPointerException.class, () -> psp.tryAdvance((LongConsumer) null));
+        }
+        else if (sp instanceof Spliterator.OfDouble) {
+            Spliterator.OfDouble psp = (Spliterator.OfDouble) sp;
+            executeAndCatch(NullPointerException.class, () -> psp.forEachRemaining((DoubleConsumer) null));
+            executeAndCatch(NullPointerException.class, () -> psp.tryAdvance((DoubleConsumer) null));
+        }
+        else {
+            executeAndCatch(NullPointerException.class, () -> sp.forEachRemaining(null));
+            executeAndCatch(NullPointerException.class, () -> sp.tryAdvance(null));
+        }
+    }
+
     private static <T, S extends Spliterator<T>> void testForEach(
             Collection<T> exp,
             Supplier<S> supplier,
@@ -573,6 +601,23 @@
         }
     }
 
+    private static void executeAndCatch(Class<? extends Exception> expected, Runnable r) {
+        Exception caught = null;
+        try {
+            r.run();
+        }
+        catch (Exception e) {
+            caught = e;
+        }
+
+        assertNotNull(caught,
+                      String.format("No Exception was thrown, expected an Exception of %s to be thrown",
+                                    expected.getName()));
+        assertTrue(expected.isInstance(caught),
+                   String.format("Exception thrown %s not an instance of %s",
+                                 caught.getClass().getName(), expected.getName()));
+    }
+
     static<U> void mixedTraverseAndSplit(Consumer<U> b, Spliterator<U> splTop) {
         Spliterator<U> spl1, spl2, spl3;
         splTop.tryAdvance(b);