changeset 7818:5a0b40214415

assert NPE is thrown for null lambdas from forEach, replaceAll, removeIf
author akhil
date Thu, 04 Apr 2013 16:18:02 -0700
parents a79e2c6b28bb
children 5ecb2472fcbf
files test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java
diffstat 2 files changed, 116 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java	Thu Apr 04 17:40:23 2013 -0400
+++ b/test/java/util/CollectionExtensionMethods/CollectionExtensionMethodsTest.java	Thu Apr 04 16:18:02 2013 -0700
@@ -23,17 +23,19 @@
  * questions.
  */
 
+import java.util.HashSet;
+import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Set;
 
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
 
-import java.lang.reflect.Constructor;
-import java.util.Collection;
-import java.util.Collections;
+import java.util.TreeSet;
 import java.util.function.Predicate;
 
 /**
@@ -56,12 +58,48 @@
 
     private static final int SIZE = 100;
 
+    @DataProvider(name="setProvider")
+    public static Object[][] setCases() {
+        final List<Object[]> cases = new LinkedList<>();
+        cases.add(new Object[] { new HashSet<>() });
+        cases.add(new Object[] { new LinkedHashSet<>() });
+        cases.add(new Object[] { new TreeSet<>() });
+
+        cases.add(new Object[] { new HashSet(){{add(42);}} });
+        cases.add(new Object[] { new LinkedHashSet(){{add(42);}} });
+        cases.add(new Object[] { new TreeSet(){{add(42);}} });
+        return cases.toArray(new Object[0][cases.size()]);
+    }
+
+    @Test(dataProvider = "setProvider")
+    public void testProvidedWithNull(final Set<Integer> set) throws Exception {
+        try {
+            set.forEach(null);
+            fail("expected NPE not thrown");
+        } catch (NullPointerException npe) {}
+        try {
+            set.removeIf(null);
+            fail("expected NPE not thrown");
+        } catch (NullPointerException npe) {}
+    }
+
     @Test
     public void testForEach() throws Exception {
         final CollectionSupplier supplier = new CollectionSupplier(SET_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final Set<Integer> original = ((Set<Integer>) test.original);
             final Set<Integer> set = ((Set<Integer>) test.collection);
+
+            try {
+                set.forEach(null);
+                fail("expected NPE not thrown");
+            } catch (NullPointerException npe) {}
+            if (test.className.equals("java.util.HashSet")) {
+                CollectionAsserts.assertContentsUnordered(set, original);
+            } else {
+                CollectionAsserts.assertContents(set, original);
+            }
+
             final List<Integer> actual = new LinkedList<>();
             set.forEach(actual::add);
             if (test.className.equals("java.util.HashSet")) {
@@ -80,6 +118,17 @@
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final Set<Integer> original = ((Set<Integer>) test.original);
             final Set<Integer> set = ((Set<Integer>) test.collection);
+
+            try {
+                set.removeIf(null);
+                fail("expected NPE not thrown");
+            } catch (NullPointerException npe) {}
+            if (test.className.equals("java.util.HashSet")) {
+                CollectionAsserts.assertContentsUnordered(set, original);
+            } else {
+                CollectionAsserts.assertContents(set, original);
+            }
+
             set.removeIf(pEven);
             for (int i : set) {
                 assertTrue((i % 2) == 1);
--- a/test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java	Thu Apr 04 17:40:23 2013 -0400
+++ b/test/java/util/CollectionExtensionMethods/ListExtensionMethodsTest.java	Thu Apr 04 16:18:02 2013 -0700
@@ -29,9 +29,15 @@
 import java.util.Comparators;
 import java.util.List;
 import java.util.LinkedList;
+import java.util.Stack;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.Vector;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
@@ -93,12 +99,56 @@
         }
     }
 
+    @DataProvider(name="listProvider")
+    public static Object[][] listCases() {
+        final List<Object[]> cases = new LinkedList<>();
+        cases.add(new Object[] { new ArrayList<>() });
+        cases.add(new Object[] { new LinkedList<>() });
+        cases.add(new Object[] { new Vector<>() });
+        cases.add(new Object[] { new Stack<>() });
+        cases.add(new Object[] { new CopyOnWriteArrayList<>() });
+
+        cases.add(new Object[] { new ArrayList(){{add(42);}} });
+        cases.add(new Object[] { new LinkedList(){{add(42);}} });
+        cases.add(new Object[] { new Vector(){{add(42);}} });
+        cases.add(new Object[] { new Stack(){{add(42);}} });
+        cases.add(new Object[] { new CopyOnWriteArrayList(){{add(42);}} });
+        return cases.toArray(new Object[0][cases.size()]);
+    }
+
+    @Test(dataProvider = "listProvider")
+    public void testProvidedWithNull(final List<Integer> list) throws Exception {
+        try {
+            list.forEach(null);
+            fail("expected NPE not thrown");
+        } catch (NullPointerException npe) {}
+        try {
+            list.replaceAll(null);
+            fail("expected NPE not thrown");
+        } catch (NullPointerException npe) {}
+        try {
+            list.removeIf(null);
+            fail("expected NPE not thrown");
+        } catch (NullPointerException npe) {}
+    }
+
     @Test
     public void testForEach() throws Exception {
         final CollectionSupplier supplier = new CollectionSupplier(LIST_CLASSES, SIZE);
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
+        }
+        for (final CollectionSupplier.TestCase test : supplier.get()) {
+            final List<Integer> original = ((List<Integer>) test.original);
+            final List<Integer> list = ((List<Integer>) test.collection);
+
+            try {
+                list.forEach(null);
+                fail("expected NPE not thrown");
+            } catch (NullPointerException npe) {}
+            CollectionAsserts.assertContents(list, original);
+
             final List<Integer> actual = new LinkedList<>();
             list.forEach(actual::add);
             CollectionAsserts.assertContents(actual, list);
@@ -132,6 +182,13 @@
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
+
+            try {
+                list.removeIf(null);
+                fail("expected NPE not thrown");
+            } catch (NullPointerException npe) {}
+            CollectionAsserts.assertContents(list, original);
+
             final AtomicInteger offset = new AtomicInteger(1);
             while (list.size() > 0) {
                 removeFirst(original, list, offset);
@@ -216,6 +273,13 @@
         for (final CollectionSupplier.TestCase test : supplier.get()) {
             final List<Integer> original = ((List<Integer>) test.original);
             final List<Integer> list = ((List<Integer>) test.collection);
+
+            try {
+                list.replaceAll(null);
+                fail("expected NPE not thrown");
+            } catch (NullPointerException npe) {}
+            CollectionAsserts.assertContents(list, original);
+
             list.replaceAll(x -> scale * x);
             for (int i=0; i < original.size(); i++) {
                 assertTrue(list.get(i) == (scale * original.get(i)), "mismatch at index " + i);