changeset 6383:1fca62c40fb7

Add parallel version of FindFirst
author briangoetz
date Fri, 09 Nov 2012 12:47:44 -0500
parents 887edbc2572b
children b904c63a9e03
files src/share/classes/java/util/streams/ops/AbstractTask.java src/share/classes/java/util/streams/ops/FindFirstOp.java test-ng/tests/org/openjdk/tests/java/util/streams/ops/FindFirstOpTest.java test-ng/tests/org/openjdk/tests/java/util/streams/ops/MatchOpTest.java test-ng/tests/org/openjdk/tests/java/util/streams/ops/ReduceTest.java
diffstat 5 files changed, 129 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/streams/ops/AbstractTask.java	Thu Nov 08 19:12:00 2012 -0500
+++ b/src/share/classes/java/util/streams/ops/AbstractTask.java	Fri Nov 09 12:47:44 2012 -0500
@@ -118,8 +118,9 @@
                     T newChild = makeChild((i > 0) ? spliterator.split() : spliterator);
                     curChild.nextSibling = newChild;
                     curChild = newChild;
-                    newChild.fork();
                 }
+                for (T child=children.nextSibling; child != null; child=child.nextSibling)
+                    child.fork();
 
                 firstChild.compute();
             }
--- a/src/share/classes/java/util/streams/ops/FindFirstOp.java	Thu Nov 08 19:12:00 2012 -0500
+++ b/src/share/classes/java/util/streams/ops/FindFirstOp.java	Fri Nov 09 12:47:44 2012 -0500
@@ -26,7 +26,11 @@
 
 import java.util.Iterator;
 import java.util.Optional;
+import java.util.concurrent.CountedCompleter;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.streams.ParallelPipelineHelper;
 import java.util.streams.PipelineHelper;
+import java.util.streams.Spliterator;
 
 /**
  * FindFirstOp
@@ -56,10 +60,90 @@
         return iterator.hasNext() ? new Optional<>(iterator.next()) : Optional.<T>empty();
     }
 
-    // Parallel strategy
-    // - extend ComparableTask, maintain AtomicRef<Task>
-    // - decompose as normal
-    // - on find, while this <= task.get() task.cas(-> this)
-    // - before compute, compare this to task.get(), cancel if greater
-    // - onComplete combines Optional
+    @Override
+    public <P_IN> Optional<T> evaluateParallel(ParallelPipelineHelper<P_IN, T> helper) {
+        // Parallel strategy
+        // - Each node maintains a result and a cancelation flag
+        // - before compute, check cancelation flag
+        // - after successful leaf, cancel siblings that are later in the order
+
+        return helper.invoke(new FindFirstTask<>(helper));
+    }
+
+    private static class FindFirstTask<S, T> extends AbstractTask<S, T, Optional<T>, FindFirstTask<S, T>> {
+        private volatile boolean canceled = false;
+
+        private FindFirstTask(ParallelPipelineHelper<S, T> helper) {
+            super(helper);
+        }
+
+        private FindFirstTask(FindFirstTask<S, T> parent, Spliterator<S> spliterator) {
+            super(parent, spliterator);
+        }
+
+        @Override
+        protected FindFirstTask<S, T> makeChild(Spliterator<S> spliterator) {
+            return new FindFirstTask<>(this, spliterator);
+        }
+
+        @Override
+        public void compute() {
+            boolean cancel = canceled;
+            for (FindFirstTask<S, T> parent = getParent(); !cancel && parent != null; parent = parent.getParent())
+                cancel = parent.canceled;
+            if (cancel) {
+                setRawResult(Optional.<T>empty());
+                helpComplete();
+            }
+            else
+                super.compute();
+        }
+
+        @Override
+        protected Optional<T> doLeaf() {
+            Iterator<T> iterator = helper.wrapIterator(spliterator.iterator());
+            if (iterator.hasNext())
+                return new Optional<>(iterator.next());
+            else
+                return Optional.empty();
+        }
+
+        @Override
+        public void onCompletion(CountedCompleter<?> caller) {
+            if (children == null) {
+                Optional<T> result = getRawResult();
+                if (result.isPresent())
+                    cancelLaterSiblings();
+            }
+            else {
+                for (FindFirstTask<S, T> child = children; child != null; child = child.nextSibling) {
+                    Optional<T> result = child.getRawResult();
+                    setRawResult(result);
+                    if (result.isPresent()) {
+                        cancelLaterSiblings();
+                        break;
+                    }
+                }
+            }
+        }
+
+        private void cancelLaterSiblings() {
+            FindFirstTask<S, T> parent = getParent();
+            if (parent != null) {
+                boolean foundMe = false;
+                for (FindFirstTask<S, T> child = parent.children; child != null; child = child.nextSibling) {
+                    if (child == this) {
+                        foundMe = true;
+                        continue;
+                    }
+                    if (foundMe)
+                        child.canceled = true;
+                }
+                // If we are the leftmost child of the parent, then we should cancel the parent's later siblings too
+                // We could be more aggressve, and actually complete the parent here when the leftmost child completes
+                if (parent.children == this)
+                    parent.cancelLaterSiblings();
+            }
+        }
+    }
 }
--- a/test-ng/tests/org/openjdk/tests/java/util/streams/ops/FindFirstOpTest.java	Thu Nov 08 19:12:00 2012 -0500
+++ b/test-ng/tests/org/openjdk/tests/java/util/streams/ops/FindFirstOpTest.java	Fri Nov 09 12:47:44 2012 -0500
@@ -28,6 +28,8 @@
 import org.testng.annotations.Test;
 
 import java.util.Collections;
+import java.util.Optional;
+import java.util.streams.ops.FilterOp;
 import java.util.streams.ops.FindFirstOp;
 
 import static org.openjdk.tests.java.util.LambdaTestHelpers.*;
@@ -51,8 +53,37 @@
         assertEquals(2, (int) countTo(10).stream().filter(pEven).findFirst().get(), "first even number is 2");
     }
 
+    public void testFindFirstParallel() {
+        assertFalse(Collections.<Integer>emptySet().parallel().findFirst().isPresent(), "no result");
+        assertFalse(countTo(1000).parallel().filter(x -> x > 1000).findFirst().isPresent(), "no result");
+        assertEquals(2, (int) countTo(1000).parallel().filter(pEven).findFirst().get(), "first even number is 2");
+    }
+
     @Test(dataProvider = "opArrays", dataProviderClass = StreamTestDataProvider.class)
     public void testOps(String name, TestData<Integer> data) {
         exerciseOps(data, FindFirstOp.<Integer>singleton());
+        exerciseOps(data, FindFirstOp.<Integer>singleton(), new FilterOp<>(pEven));
+        exerciseOps(data, FindFirstOp.<Integer>singleton(), new FilterOp<>(pTrue));
+        exerciseOps(data, FindFirstOp.<Integer>singleton(), new FilterOp<>(pFalse));
+    }
+
+    @Test(dataProvider = "opArrays", dataProviderClass = StreamTestDataProvider.class)
+    public void testPipelines(String name, TestData<Integer> data) {
+        Optional<Integer> seq, par;
+        seq = data.seqStream().findFirst();
+        par = data.parStream().findFirst();
+        assertEquals(par, seq);
+
+        seq = data.seqStream().filter(pEven).findFirst();
+        par = data.parStream().filter(pEven).findFirst();
+        assertEquals(par, seq);
+
+        seq = data.seqStream().filter(pTrue).findFirst();
+        par = data.parStream().filter(pTrue).findFirst();
+        assertEquals(par, seq);
+
+        seq = data.seqStream().filter(pFalse).findFirst();
+        par = data.parStream().filter(pFalse).findFirst();
+        assertEquals(par, seq);
     }
 }
--- a/test-ng/tests/org/openjdk/tests/java/util/streams/ops/MatchOpTest.java	Thu Nov 08 19:12:00 2012 -0500
+++ b/test-ng/tests/org/openjdk/tests/java/util/streams/ops/MatchOpTest.java	Fri Nov 09 12:47:44 2012 -0500
@@ -30,6 +30,7 @@
 import java.util.functions.Predicate;
 import java.util.streams.Stream;
 import java.util.streams.Streamable;
+import java.util.streams.ops.FilterOp;
 import java.util.streams.ops.MatchOp;
 import java.util.streams.ops.TerminalOp;
 
@@ -79,6 +80,8 @@
         for (Predicate<Integer> p : INTEGER_PREDICATES)
             for (MatchKind matchKind : MatchKind.values()) {
                 exerciseOps(data, MatchOp.make(p, matchKind));
+                exerciseOps(data, MatchOp.make(p, matchKind), new FilterOp<>(pFalse));
+                exerciseOps(data, MatchOp.make(p, matchKind), new FilterOp<>(pEven));
             }
     }
 }
--- a/test-ng/tests/org/openjdk/tests/java/util/streams/ops/ReduceTest.java	Thu Nov 08 19:12:00 2012 -0500
+++ b/test-ng/tests/org/openjdk/tests/java/util/streams/ops/ReduceTest.java	Fri Nov 09 12:47:44 2012 -0500
@@ -29,6 +29,7 @@
 
 import java.util.List;
 import java.util.Optional;
+import java.util.streams.ops.FilterOp;
 import java.util.streams.ops.FoldOp;
 import java.util.streams.ops.MapOp;
 import java.util.streams.ops.SeedlessFoldOp;
@@ -60,6 +61,8 @@
 
     @Test(dataProvider = "opArrays", dataProviderClass = StreamTestDataProvider.class)
     public void testOps(String name, TestData<Integer> data) {
+        assertEquals(0, (int) exerciseOps(data, new FoldOp<>(() -> 0, rPlus, rPlus), new FilterOp<>(pFalse)));
+
         Optional<Integer> seedless = exerciseOps(data, new SeedlessFoldOp<>(rPlus));
         Integer folded = exerciseOps(data, new FoldOp<>(() -> 0, rPlus, rPlus));
         assertEquals(folded, seedless.orElse(0));