changeset 7283:a806c8b615b3

Cleanup in Collectors implementation of averaging; move to something more like ParallelArray.getStatistics
author briangoetz
date Tue, 05 Feb 2013 16:57:10 -0500
parents ad394630273b
children 54630475f7bd
files src/share/classes/java/util/stream/Collectors.java src/share/classes/java/util/stream/DoubleStream.java src/share/classes/java/util/stream/IntStream.java src/share/classes/java/util/stream/LongStream.java test-ng/tests/org/openjdk/tests/java/util/stream/PrimitiveSumTest.java
diffstat 5 files changed, 128 insertions(+), 109 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/stream/Collectors.java	Tue Feb 05 14:07:45 2013 -0500
+++ b/src/share/classes/java/util/stream/Collectors.java	Tue Feb 05 16:57:10 2013 -0500
@@ -40,8 +40,11 @@
 import java.util.StringJoiner;
 import java.util.function.BiConsumer;
 import java.util.function.BinaryOperator;
+import java.util.function.DoubleConsumer;
 import java.util.function.Function;
 import java.util.function.Functions;
+import java.util.function.IntConsumer;
+import java.util.function.LongConsumer;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 
@@ -348,134 +351,152 @@
         };
     }
 
-    public static final class IntSumAsLong implements Collector.OfInt<IntSumAsLong.State> {
-        public static class State {
-            private long sum;
+    public static final class LongStatistics implements LongConsumer, IntConsumer {
+        private long count;
+        private long sum;
+        private long min;
+        private long max;
 
-            public long sum() {
-                return sum;
-            }
+        @Override
+        public void accept(int value) {
+            accept((long) value);
         }
 
         @Override
-        public State makeResult() {
-            return new State();
+        public void accept(long value) {
+            ++count;
+            sum += value;
+            min = Math.min(min, value);
+            max = Math.max(max, value);
         }
 
-        @Override
-        public void accumulateAsInt(State accumulator, int value) {
-            accumulator.sum += value;
+        private void combine(LongStatistics other) {
+            count += other.count;
+            sum += other.sum;
+            min = Math.min(min, other.min);
+            max = Math.min(max, other.max);
         }
 
-        @Override
-        public State combine(State result, State other) {
-            result.sum += other.sum;
-            return result;
+        public long getCount() {
+            return count;
+        }
+
+        public long getSum() {
+            return sum;
+        }
+
+        public long getMin() {
+            return min;
+        }
+
+        public long getMax() {
+            return max;
+        }
+
+        public OptionalDouble getMean() {
+            return count > 0 ? OptionalDouble.of((double) sum / count) : OptionalDouble.empty();
         }
     }
 
-    public static Collector.OfInt<IntSumAsLong.State> intSumAsLong() {
-        return new IntSumAsLong();
-    }
+    public static final class DoubleStatistics implements DoubleConsumer {
+        private long count;
+        private double sum;
+        private double min;
+        private double max;
 
-    public static final class IntCountAndSumAsLong implements Collector.OfInt<IntCountAndSumAsLong.State> {
-        public static class State {
-            private long count;
-            private long sum;
-
-            public OptionalDouble mean() {
-                return count > 0 ? OptionalDouble.of((double) sum / count) : OptionalDouble.empty();
-            }
+        @Override
+        public void accept(double value) {
+            ++count;
+            sum += value;
+            min = Math.min(min, value);
+            max = Math.max(max, value);
         }
 
-        @Override
-        public State makeResult() {
-            return new State();
+        private void combine(DoubleStatistics other) {
+            count += other.count;
+            sum += other.sum;
+            min = Math.min(min, other.min);
+            max = Math.min(max, other.max);
         }
 
-        @Override
-        public void accumulateAsInt(State accumulator, int value) {
-            accumulator.count++;
-            accumulator.sum += value;
+        public long getCount() {
+            return count;
         }
 
-        @Override
-        public State combine(State result, State other) {
-            result.count += other.count;
-            result.sum += other.sum;
-            return result;
+        public double getSum() {
+            return sum;
+        }
+
+        public double getMin() {
+            return min;
+        }
+
+        public double getMax() {
+            return max;
+        }
+
+        public OptionalDouble getMean() {
+            return count > 0 ? OptionalDouble.of(sum / count) : OptionalDouble.empty();
         }
     }
 
-    public static Collector.OfInt<IntCountAndSumAsLong.State> intCountAndSumAsLong() {
-        return new IntCountAndSumAsLong();
+    public static Collector.OfInt<LongStatistics> toIntStatistics() {
+        return new Collector.OfInt<LongStatistics>() {
+            @Override
+            public LongStatistics makeResult() {
+                return new LongStatistics();
+            }
+
+            @Override
+            public void accumulateAsInt(LongStatistics accumulator, int value) {
+                accumulator.accept(value);
+            }
+
+            @Override
+            public LongStatistics combine(LongStatistics result, LongStatistics other) {
+                result.combine(other);
+                return result;
+            }
+        };
     }
 
-    public static final class LongCountAndSum implements Collector.OfLong<LongCountAndSum.State> {
-        public static class State {
-            private long count;
-            private long sum;
+    public static Collector.OfLong<LongStatistics> toLongStatistics() {
+        return new Collector.OfLong<LongStatistics>() {
+            @Override
+            public LongStatistics makeResult() {
+                return new LongStatistics();
+            }
 
-            public OptionalDouble mean() {
-                return count > 0 ? OptionalDouble.of((double) sum / count) : OptionalDouble.empty();
+            @Override
+            public void accumulateAsLong(LongStatistics accumulator, long value) {
+                accumulator.accept(value);
             }
-        }
 
-        @Override
-        public State makeResult() {
-            return new State();
-        }
-
-        @Override
-        public void accumulateAsLong(State accumulator, long value) {
-            accumulator.count++;
-            accumulator.sum += value;
-        }
-
-        @Override
-        public State combine(State result, State other) {
-            result.count += other.count;
-            result.sum += other.sum;
-            return result;
-        }
+            @Override
+            public LongStatistics combine(LongStatistics result, LongStatistics other) {
+                result.combine(other);
+                return result;
+            }
+        };
     }
 
-    public static Collector.OfLong<LongCountAndSum.State> longCountAndSum() {
-        return new LongCountAndSum();
+    public static Collector.OfDouble<DoubleStatistics> toDoubleStatistics() {
+        return new Collector.OfDouble<DoubleStatistics>() {
+            @Override
+            public DoubleStatistics makeResult() {
+                return new DoubleStatistics();
+            }
+
+            @Override
+            public void accumulateAsDouble(DoubleStatistics accumulator, double value) {
+                accumulator.accept(value);
+            }
+
+            @Override
+            public DoubleStatistics combine(DoubleStatistics result, DoubleStatistics other) {
+                result.combine(other);
+                return result;
+            }
+        };
     }
-
-    // @@@ better algorithm to compensate for errors
-    public static final class DoubleCountAndSum implements Collector.OfDouble<DoubleCountAndSum.State> {
-        public static class State {
-            private long count;
-            private double sum;
-
-            public OptionalDouble mean() {
-                return count > 0 ? OptionalDouble.of(sum / count) : OptionalDouble.empty();
-            }
-        }
-
-        @Override
-        public State makeResult() {
-            return new State();
-        }
-
-        @Override
-        public void accumulateAsDouble(State accumulator, double value) {
-            accumulator.count++;
-            accumulator.sum += value;
-        }
-
-        @Override
-        public State combine(State result, State other) {
-            result.count += other.count;
-            result.sum += other.sum;
-            return result;
-        }
-    }
-
-    public static Collector.OfDouble<DoubleCountAndSum.State> doubleCountAndSum() {
-        return new DoubleCountAndSum();
-    }
-
 }
--- a/src/share/classes/java/util/stream/DoubleStream.java	Tue Feb 05 14:07:45 2013 -0500
+++ b/src/share/classes/java/util/stream/DoubleStream.java	Tue Feb 05 16:57:10 2013 -0500
@@ -118,7 +118,7 @@
     }
 
     default OptionalDouble average() {
-        return collect(Collectors.doubleCountAndSum()).mean();
+        return collect(Collectors.toDoubleStatistics()).getMean();
     }
 
     double[] toArray();
--- a/src/share/classes/java/util/stream/IntStream.java	Tue Feb 05 14:07:45 2013 -0500
+++ b/src/share/classes/java/util/stream/IntStream.java	Tue Feb 05 16:57:10 2013 -0500
@@ -36,8 +36,6 @@
 import java.util.function.IntPredicate;
 import java.util.function.IntUnaryOperator;
 import java.util.function.ObjIntConsumer;
-import java.util.logging.Level;
-import java.util.logging.Logger;
 
 public interface IntStream extends BaseStream<Integer, IntStream> {
 
@@ -109,8 +107,8 @@
 
     void forEachUntil(IntConsumer consumer, BooleanSupplier until);
 
-    default long sum() {
-        return collect(Collectors.intSumAsLong()).sum();
+    default int sum() {
+        return reduce(0, Integer::sum);
     }
 
     default OptionalInt min() {
@@ -122,7 +120,7 @@
     }
 
     default OptionalDouble average() {
-        return collect(Collectors.intCountAndSumAsLong()).mean();
+        return collect(Collectors.toIntStatistics()).getMean();
     }
 
     int[] toArray();
--- a/src/share/classes/java/util/stream/LongStream.java	Tue Feb 05 14:07:45 2013 -0500
+++ b/src/share/classes/java/util/stream/LongStream.java	Tue Feb 05 16:57:10 2013 -0500
@@ -121,7 +121,7 @@
     }
 
     default OptionalDouble average() {
-        return collect(Collectors.longCountAndSum()).mean();
+        return collect(Collectors.toLongStatistics()).getMean();
     }
 
     long[] toArray();
--- a/test-ng/tests/org/openjdk/tests/java/util/stream/PrimitiveSumTest.java	Tue Feb 05 14:07:45 2013 -0500
+++ b/test-ng/tests/org/openjdk/tests/java/util/stream/PrimitiveSumTest.java	Tue Feb 05 16:57:10 2013 -0500
@@ -35,7 +35,7 @@
         exerciseTerminalOps(data, s -> s.sum());
 
         withData(data).
-                terminal(s -> s.sum()).
+                terminal(s -> (long) s.sum()).
                 expectedResult(data.stream().longs().reduce(0, LambdaTestHelpers.lrPlus)).
                 exercise();
     }