changeset 6411:ae4e5994cbb4

Parallel version of ReduceByOp
author briangoetz
date Thu, 15 Nov 2012 13:36:22 -0500
parents e80b94da6538
children ef5e16bf1045
files src/share/classes/java/util/streams/ReferencePipeline.java src/share/classes/java/util/streams/Stream.java src/share/classes/java/util/streams/ops/GroupByOp.java src/share/classes/java/util/streams/ops/ReduceByOp.java src/share/classes/java/util/streams/ops/SliceOp.java test-ng/tests/org/openjdk/tests/java/util/streams/ops/ReduceByOpTest.java
diffstat 6 files changed, 170 insertions(+), 94 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/streams/ReferencePipeline.java	Thu Nov 15 17:55:02 2012 +0100
+++ b/src/share/classes/java/util/streams/ReferencePipeline.java	Thu Nov 15 13:36:22 2012 -0500
@@ -130,8 +130,9 @@
     @Override
     public <K, W> Map<K, W> reduceBy(Mapper<? extends K, ? super U> classifier,
                                      Factory<W> baseFactory,
-                                     Combiner<W, W, U> reducer) {
-        return pipeline(new ReduceByOp<>(classifier, baseFactory, reducer));
+                                     Combiner<W, W, U> reducer,
+                                     BinaryOperator<W> combiner) {
+        return pipeline(new ReduceByOp<>(classifier, baseFactory, reducer, combiner));
     }
 
     @Override
--- a/src/share/classes/java/util/streams/Stream.java	Thu Nov 15 17:55:02 2012 +0100
+++ b/src/share/classes/java/util/streams/Stream.java	Thu Nov 15 13:36:22 2012 -0500
@@ -107,7 +107,8 @@
 
     <U, W> Map<U, W> reduceBy(Mapper<? extends U, ? super T> classifier,
                               Factory<W> baseFactory,
-                              Combiner<W, W, T> reducer);
+                              Combiner<W, W, T> reducer,
+                              BinaryOperator<W> combiner);
 
     T reduce(T base, BinaryOperator<T> op);
 
--- a/src/share/classes/java/util/streams/ops/GroupByOp.java	Thu Nov 15 17:55:02 2012 +0100
+++ b/src/share/classes/java/util/streams/ops/GroupByOp.java	Thu Nov 15 13:36:22 2012 -0500
@@ -68,6 +68,7 @@
     @Override
     public <S> Map<K, Collection<T>> evaluateParallel(ParallelPipelineHelper<S, T> helper) {
         if (StreamOpFlags.ORDERED.isKnown(helper.getStreamFlags())) {
+            // @@@ Should be able to use a ctor ref here, but we get a runtime failure
             return OpUtils.parallelReduce(helper, () -> new GroupBySink());
         }
         else {
--- a/src/share/classes/java/util/streams/ops/ReduceByOp.java	Thu Nov 15 17:55:02 2012 +0100
+++ b/src/share/classes/java/util/streams/ops/ReduceByOp.java	Thu Nov 15 13:36:22 2012 -0500
@@ -24,13 +24,16 @@
  */
 package java.util.streams.ops;
 
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.functions.BinaryOperator;
 import java.util.functions.Combiner;
 import java.util.functions.Factory;
 import java.util.functions.Mapper;
-import java.util.streams.PipelineHelper;
-import java.util.streams.TerminalSink;
+import java.util.streams.*;
 
 /**
  * ReduceByOp
@@ -39,15 +42,18 @@
  */
 public class ReduceByOp<T, U, W> implements TerminalOp<T,Map<U,W>> {
     private final Mapper<? extends U, ? super T> classifier;
-    private Factory<W> seedFactory;
-    private Combiner<W, W, T> reducer;
+    private final Factory<W> seedFactory;
+    private final Combiner<W, W, T> reducer;
+    private final BinaryOperator<W> combiner;
 
     public ReduceByOp(Mapper<? extends U, ? super T> classifier,
                       Factory<W> seedFactory,
-                      Combiner<W, W, T> reducer) {
+                      Combiner<W, W, T> reducer,
+                      BinaryOperator<W> combiner) {
         this.classifier = classifier;
         this.seedFactory = seedFactory;
         this.reducer = reducer;
+        this.combiner = combiner;
     }
 
     private TerminalSink<T, Map<U, W>> sink() {
@@ -84,4 +90,70 @@
     public <S> Map<U, W> evaluateSequential(PipelineHelper<S, T> helper) {
         return helper.into(sink()).getAndClearState();
     }
+
+    @Override
+    public <S> Map<U, W> evaluateParallel(ParallelPipelineHelper<S, T> helper) {
+        if (StreamOpFlags.ORDERED.isKnown(helper.getStreamFlags())) {
+            // @@@ Should be able to use a ctor ref here, but we get a runtime failure
+            return OpUtils.parallelReduce(helper, () -> new ReduceBySink());
+        }
+        else {
+            final ConcurrentHashMap<U, W> map = new ConcurrentHashMap<>();
+            final ConcurrentHashMap.Fun<? super U, ? extends W> seedFactoryAsCHMFun = (k) -> seedFactory.make();
+
+            // Cache the sink chain, so it can be reused by all F/J leaf tasks
+            Sink<S> sinkChain = helper.wrapSink(new Sink.OfValue<T>() {
+                @Override
+                public void apply(T t) {
+                    U key = classifier.map(t);
+                    W curValue = map.computeIfAbsent(key, seedFactoryAsCHMFun);
+                    while (!map.replace(key, curValue, reducer.combine(curValue, t)))
+                        curValue = map.get(key);
+                }
+            });
+
+            OpUtils.parallelForEach(helper, sinkChain);
+
+            return map;
+        }
+    }
+
+    private class ReduceBySink implements OpUtils.AccumulatingSink<T, Map<U, W>, ReduceBySink> {
+        Map<U, W> map;
+
+        @Override
+        public void begin(int size) {
+            map = new HashMap<>();
+        }
+
+        @Override
+        public void clearState() {
+            map = null;
+        }
+
+        @Override
+        public Map<U, W> getAndClearState() {
+            Map<U, W> result = map;
+            map = null;
+            return result;
+        }
+
+        @Override
+        public void apply(T t) {
+            U key = Objects.requireNonNull(classifier.map(t), String.format("The element %s cannot be mapped to a null key", t));
+            W r = map.get(key);
+            if (r == null)
+                r = seedFactory.make();
+            map.put(key, reducer.combine(r, t));
+        }
+
+        @Override
+        public void combine(ReduceBySink other) {
+            for (Map.Entry<U, W> e : other.map.entrySet()) {
+                U key = e.getKey();
+                W newValue = map.containsKey(key) ? combiner.operate(map.get(key), e.getValue()) : e.getValue();
+                map.put(key, newValue);
+            }
+        }
+    }
 }
--- a/src/share/classes/java/util/streams/ops/SliceOp.java	Thu Nov 15 17:55:02 2012 +0100
+++ b/src/share/classes/java/util/streams/ops/SliceOp.java	Thu Nov 15 13:36:22 2012 -0500
@@ -126,7 +126,7 @@
 
     @Override
     public String toString() {
-        return String.format("SliceOp[skip=%d,limit=%d", skip, limit);
+        return String.format("SliceOp[skip=%d,limit=%d]", skip, limit);
     }
 
     private static class SliceIterator<T> implements Iterator<T> {
@@ -277,88 +277,89 @@
         }
     }
 
-    private static class SizedSliceTask<S, T> extends AbstractShortCircuitTask<S, T, Node<T>, SizedSliceTask<S, T>> {
-        private final int targetOffset, targetSize;
-        private final int offset, size;
-
-        private SizedSliceTask(ParallelPipelineHelper<S, T> helper, int offset, int size) {
-            super(helper);
-            targetOffset = offset;
-            targetSize = size;
-            this.offset = 0;
-            this.size = spliterator.getSizeIfKnown();
-        }
-
-        private SizedSliceTask(SizedSliceTask<S, T> parent, Spliterator<S> spliterator) {
-            // Makes assumptions about order in which siblings are created and linked into parent!
-            super(parent, spliterator);
-            targetOffset = parent.targetOffset;
-            targetSize = parent.targetSize;
-            int siblingSizes = 0;
-            for (SizedSliceTask<S, T> sibling = parent.children; sibling != null; sibling = sibling.nextSibling)
-                siblingSizes += sibling.size;
-            size = spliterator.getSizeIfKnown();
-            offset = parent.offset + siblingSizes;
-        }
-
-        @Override
-        protected SizedSliceTask<S, T> makeChild(Spliterator<S> spliterator) {
-            return new SizedSliceTask<>(this, spliterator);
-        }
-
-        @Override
-        protected Node<T> getEmptyResult() {
-            return Nodes.emptyNode();
-        }
-
-        @Override
-        public boolean taskCancelled() {
-            if (offset > targetOffset+targetSize || offset+size < targetOffset)
-                return true;
-            else
-                return super.taskCancelled();
-        }
-
-        @Override
-        protected Node<T> doLeaf() {
-            int skipLeft = Math.max(0, targetOffset - offset);
-            int skipRight = Math.max(0, offset + size - (targetOffset + targetSize));
-            if (skipLeft == 0 && skipRight == 0)
-                return helper.into(Nodes.<T>makeBuilder(spliterator.getSizeIfKnown())).build();
-            else {
-                // If we're the first or last node that intersects the target range, peel off irrelevant elements
-                int truncatedSize = size - skipLeft - skipRight;
-                NodeBuilder<T> builder = Nodes.<T>makeBuilder(truncatedSize);
-                Sink<S> wrappedSink = helper.wrapSink(builder);
-                wrappedSink.begin(truncatedSize);
-                Iterator<S> iterator = spliterator.iterator();
-                for (int i=0; i<skipLeft; i++)
-                    iterator.next();
-                for (int i=0; i<truncatedSize; i++)
-                    wrappedSink.apply(iterator.next());
-                wrappedSink.end();
-                return builder.build();
-            }
-        }
-
-        @Override
-        public void onCompletion(CountedCompleter<?> caller) {
-            if (!isLeaf()) {
-                Node<T> result = null;
-                for (SizedSliceTask<S, T> child = children.nextSibling; child != null; child = child.nextSibling) {
-                    Node<T> childResult = child.getRawResult();
-                    if (childResult == null)
-                        continue;
-                    else if (result == null)
-                        result = childResult;
-                    else
-                        result = Nodes.node(result, childResult);
-                }
-                setRawResult(result);
-                if (offset <= targetOffset && offset+size >= targetOffset+targetSize)
-                    shortCircuit(result);
-            }
-        }
-    }
+    // @@@ Currently unused -- optimization for when all sizes are known
+//    private static class SizedSliceTask<S, T> extends AbstractShortCircuitTask<S, T, Node<T>, SizedSliceTask<S, T>> {
+//        private final int targetOffset, targetSize;
+//        private final int offset, size;
+//
+//        private SizedSliceTask(ParallelPipelineHelper<S, T> helper, int offset, int size) {
+//            super(helper);
+//            targetOffset = offset;
+//            targetSize = size;
+//            this.offset = 0;
+//            this.size = spliterator.getSizeIfKnown();
+//        }
+//
+//        private SizedSliceTask(SizedSliceTask<S, T> parent, Spliterator<S> spliterator) {
+//            // Makes assumptions about order in which siblings are created and linked into parent!
+//            super(parent, spliterator);
+//            targetOffset = parent.targetOffset;
+//            targetSize = parent.targetSize;
+//            int siblingSizes = 0;
+//            for (SizedSliceTask<S, T> sibling = parent.children; sibling != null; sibling = sibling.nextSibling)
+//                siblingSizes += sibling.size;
+//            size = spliterator.getSizeIfKnown();
+//            offset = parent.offset + siblingSizes;
+//        }
+//
+//        @Override
+//        protected SizedSliceTask<S, T> makeChild(Spliterator<S> spliterator) {
+//            return new SizedSliceTask<>(this, spliterator);
+//        }
+//
+//        @Override
+//        protected Node<T> getEmptyResult() {
+//            return Nodes.emptyNode();
+//        }
+//
+//        @Override
+//        public boolean taskCancelled() {
+//            if (offset > targetOffset+targetSize || offset+size < targetOffset)
+//                return true;
+//            else
+//                return super.taskCancelled();
+//        }
+//
+//        @Override
+//        protected Node<T> doLeaf() {
+//            int skipLeft = Math.max(0, targetOffset - offset);
+//            int skipRight = Math.max(0, offset + size - (targetOffset + targetSize));
+//            if (skipLeft == 0 && skipRight == 0)
+//                return helper.into(Nodes.<T>makeBuilder(spliterator.getSizeIfKnown())).build();
+//            else {
+//                // If we're the first or last node that intersects the target range, peel off irrelevant elements
+//                int truncatedSize = size - skipLeft - skipRight;
+//                NodeBuilder<T> builder = Nodes.<T>makeBuilder(truncatedSize);
+//                Sink<S> wrappedSink = helper.wrapSink(builder);
+//                wrappedSink.begin(truncatedSize);
+//                Iterator<S> iterator = spliterator.iterator();
+//                for (int i=0; i<skipLeft; i++)
+//                    iterator.next();
+//                for (int i=0; i<truncatedSize; i++)
+//                    wrappedSink.apply(iterator.next());
+//                wrappedSink.end();
+//                return builder.build();
+//            }
+//        }
+//
+//        @Override
+//        public void onCompletion(CountedCompleter<?> caller) {
+//            if (!isLeaf()) {
+//                Node<T> result = null;
+//                for (SizedSliceTask<S, T> child = children.nextSibling; child != null; child = child.nextSibling) {
+//                    Node<T> childResult = child.getRawResult();
+//                    if (childResult == null)
+//                        continue;
+//                    else if (result == null)
+//                        result = childResult;
+//                    else
+//                        result = Nodes.node(result, childResult);
+//                }
+//                setRawResult(result);
+//                if (offset <= targetOffset && offset+size >= targetOffset+targetSize)
+//                    shortCircuit(result);
+//            }
+//        }
+//    }
 
 }
--- a/test-ng/tests/org/openjdk/tests/java/util/streams/ops/ReduceByOpTest.java	Thu Nov 15 17:55:02 2012 +0100
+++ b/test-ng/tests/org/openjdk/tests/java/util/streams/ops/ReduceByOpTest.java	Thu Nov 15 13:36:22 2012 -0500
@@ -50,14 +50,14 @@
     public void testOps(String name, StreamTestData<Integer> data) {
         Map<Boolean,Collection<Integer>> gbResult = data.stream().groupBy(Mappers.forPredicate(pEven, true, false));
         Map<Boolean,Integer> result = data.stream().reduceBy(Mappers.forPredicate(pEven, true, false),
-                                                                () -> 0, rPlus);
+                                                                () -> 0, rPlus, rPlus);
         assertEquals(result.size(), gbResult.size());
         for (Map.Entry<Boolean, Integer> entry : result.entrySet())
             assertEquals(entry.getValue(), data.stream().filter(e -> pEven.test(e) == entry.getKey()).reduce(0, rPlus));
 
         int uniqueSize = data.into(new HashSet<Integer>()).size();
         Map<Integer, Collection<Integer>> mgResult = exerciseOps(data, new GroupByOp<>(mId));
-        Map<Integer, Integer> miResult = exerciseOps(data, new ReduceByOp<Integer, Integer, Integer>(mId, () -> 0, (w, t) -> w + 1));
+        Map<Integer, Integer> miResult = exerciseOps(data, new ReduceByOp<Integer, Integer, Integer>(mId, () -> 0, (w, t) -> w + 1, (w, u) -> w + u));
         assertEquals(miResult.keySet().size(), uniqueSize);
         for (Map.Entry<Integer, Integer> entry : miResult.entrySet())
             assertEquals((int) entry.getValue(), mgResult.get(entry.getKey()).size());