src/share/classes/java/util/ParallelIterables.java
author mduigou
Wed Oct 26 09:34:13 2011 -0700 (6 months ago)
changeset 4385 38dba9d15bf0
parent 43826b4b6ae2f1c5
child 47189a0b17ea5292
permissions -rwxr-xr-x
nits and javadoc.
        1 package java.util;
        2 
        3 import java.util.concurrent.ForkJoinUtils;
        4 import java.util.concurrent.RecursiveAction;
        5 import java.util.functions.*;
        6 
        7 /**
        8  * ParallelIterables
        9  */
       10 public final class ParallelIterables {
       /**
        * Prevents instantiation; this class only exposes static members.
        *
        * @throws AssertionError always.
        */
       private ParallelIterables() {
           // AssertionError is a subclass of Error, so anything that caught the
           // previous raw Error still sees this failure.
           throw new AssertionError("No instances for you!");
       }
       14 
       15     /**
       16      * Return {@code true} if the iterable contains no elements.
       17      *
       18      * @param <T> Type of elements
       19      * @param iterable The source of elements.
       20      * @return {@code true} if the Iterable contains no elements.
       21      */
       22     public static <T> boolean isEmpty(ParallelIterable<T> iterable) {
       23         Objects.requireNonNull(iterable);
       24         return !iterable.sequential().iterator().hasNext();
       25     }
       26 
       27     private static<T> int calculateDepth(long s) {
       28         long initialSize = s;
       29         long leafSize = 1 + ((s + 7) >>> 3) / ForkJoinUtils.defaultFJPool().getParallelism();
       30         int d = 0;
       31         while (s > leafSize) {
       32             s /= 2;
       33             ++d;
       34         }
       35         // System.out.printf("Size=%d, depth=%d%n", initialSize, d);
       36         return d;
       37     }
       38 
       39     private static abstract class BaseTask<T, S extends BaseTask<T, S>> extends RecursiveAction {
       40         public final int depth;
       41         public final ParallelIterable<T> coll;
       42 
       43         protected BaseTask(int depth, ParallelIterable<T> coll) {
       44             this.depth = depth;
       45             this.coll = coll;
       46         }
       47 
       48         public abstract void seq();
       49 
       50         public void combine(S left, S right) { }
       51 
       52         public abstract S makeTask(int depth, ParallelIterable<T> coll);
       53 
       54         @Override
       55         protected void compute() {
       56             if (depth == 0)
       57                 seq();
       58             else {
       59                 S left = makeTask(depth-1, coll.left());
       60                 S right = makeTask(depth-1, coll.right());
       61                 right.fork();
       62                 left.compute();
       63                 right.join();
       64                 combine(left, right);
       65             }
       66         }
       67     }
       68 
       69     private static class CountTask<T> extends BaseTask<T, CountTask<T>> {
       70         public long count;
       71 
       72         CountTask(int depth, ParallelIterable<T> coll) {
       73             super(depth, coll);
       74         }
       75 
       76         @Override
       77         public void seq() {
       78             count = Iterables.count(coll.sequential());
       79         }
       80 
       81         @Override
       82         public void combine(CountTask<T> left, CountTask<T> right) {
       83             count = left.count + right.count;
       84         }
       85 
       86         @Override
       87         public CountTask<T> makeTask(int depth, ParallelIterable<T> coll) {
       88             return new CountTask<>(depth, coll);
       89         }
       90     }
       91 
       92     public static<T> long count(ParallelIterable<T> pi) {
       93         Objects.requireNonNull(pi);
       94         if (pi instanceof Collection)
       95             return ((Collection) pi).size();
       96         else {
       97             CountTask<T> task = new CountTask<>(calculateDepth(pi.estimateCount()), pi);
       98             ForkJoinUtils.defaultFJPool().invoke(task);
       99             return task.count;
      100         }
      101     }
      102 
      103     private static class Filtered<T> implements ParallelIterable<T> {
      104         private final ParallelIterable<T> underlying;
      105         private final Predicate<T> predicate;
      106 
      107         private Filtered(ParallelIterable<T> underlying, Predicate<T> predicate) {
      108             this.underlying = underlying;
      109             this.predicate = predicate;
      110         }
      111 
      112         @Override
      113         public long estimateCount() {
      114             return underlying.estimateCount();
      115         }
      116 
      117         @Override
      118         public ParallelIterable<T> left() {
      119             return new Filtered<>(underlying.left(), predicate);
      120         }
      121 
      122         @Override
      123         public ParallelIterable<T> right() {
      124             return new Filtered<>(underlying.right(), predicate);
      125         }
      126 
      127         @Override
      128         public Iterable<T> sequential() {
      129             // @@@ Wrong!  This is a sequential traversal
      130             return underlying.sequential().filter(predicate);
      131         }
      132     }
      133 
      134     /**
      135      * Filter elements according to the provided {@code predicate} and return a
      136      * an Iterable view of the filtered elements. The filtered view will reflect
      137      * changes in the provided {@code iterable}.
      138      *
      139      * @param <T> Type of elements
      140      * @param pi The source of elements.
      141      * @param predicate Decides which elements should be included in the
      142      * resulting Iterable view. Each element with a {@code true} result will be
      143      * included in the resulting view.
      144      * @return An Iterable view of the filtered elements.
      145      */
      146     public static <T> ParallelIterable<T> filter(final ParallelIterable<T> pi, final Predicate<? super T> predicate) {
      147         Objects.requireNonNull(pi);
      148         Objects.requireNonNull(predicate);
      149         return new Filtered(pi, predicate);
      150     }
      151 
      152     private static class ForEachTask<T> extends BaseTask<T, ForEachTask<T>> {
      153         private static final long serialVersionUID = 1L;
      154         private final Block<? super T> block;
      155 
      156         ForEachTask(int depth, ParallelIterable<T> coll, Block<? super T> block) {
      157             super(depth, coll);
      158             this.block = block;
      159         }
      160 
      161         @Override
      162         public void seq() {
      163             coll.sequential().forEach(block);
      164         }
      165 
      166         @Override
      167         public ForEachTask<T> makeTask(int depth, ParallelIterable<T> coll) {
      168             return new ForEachTask<>(depth, coll, block);
      169         }
      170     }
      171 
      172     /**
      173      * Performs the operation specified by {@code block} upon each element.
      174      *
      175      * <p/>This implementation is eager and performs the operation upon elements
      176      * before returning. As such, it should be used only with finite iterables.
      177      *
      178      * @param <T> Type of elements
      179      * @param pi The source of elements.
      180      * @param block The operation to be performed upon each each element.
      181      */
      182     public static <T> void forEach(final ParallelIterable<T> pi, final Block<? super T> block) {
      183         Objects.requireNonNull(pi);
      184         Objects.requireNonNull(block);
      185         ForkJoinUtils.defaultFJPool().invoke(new ForEachTask<>(calculateDepth(pi.estimateCount()), pi, block));
      186     }
      187 
      188     private static class Mapped<T, U> implements ParallelIterable<U> {
      189         private final ParallelIterable<T> underlying;
      190         private final Mapper<? super T, ? extends U> mapper;
      191 
      192         private Mapped(ParallelIterable<T> underlying, Mapper<? super T, ? extends U> mapper) {
      193             this.underlying = underlying;
      194             this.mapper = mapper;
      195         }
      196 
      197         @Override
      198         public long estimateCount() {
      199             return underlying.estimateCount();
      200         }
      201 
      202         @Override
      203         public ParallelIterable<U> left() {
      204             return new Mapped<>(underlying.left(), mapper);
      205         }
      206 
      207         @Override
      208         public ParallelIterable<U> right() {
      209             return new Mapped<>(underlying.right(), mapper);
      210         }
      211 
      212         @Override
      213         public Iterable<U> sequential() {
      214             // @@@ Wrong!  This is a sequential traversal
      215             return underlying.sequential().map(mapper);
      216         }
      217     }
      218 
      219     /**
      220      * Map the elements of an Iterable and return an Iterable view containing
      221      * the mapped elements.
      222      *
      223      * @param <T> Type of elements
      224      * @param <U> Type of the returned elements.
      225      * @param pi The source of elements.
      226      * @param mapper Performs the mapping between elements of type {@code T}
      227      * and type {@code U}.
      228      * @return An Iterable view consisting of the mapped elements.
      229      */
      230     public static <T, U> ParallelIterable<U> map(final ParallelIterable<T> pi, final Mapper<? super T, ? extends U> mapper) {
      231         Objects.requireNonNull(pi);
      232         Objects.requireNonNull(mapper);
      233         return new Mapped<>(pi, mapper);
      234     }
      235 
      236     private static class ReduceTask<T> extends BaseTask<T, ReduceTask<T>> {
      237         private static final long serialVersionUID = 1L;
      238         public final Operator<T> operator;
      239         public final T base;
      240         public T value;
      241 
      242         ReduceTask(int depth, ParallelIterable<T> coll, T base, Operator<T> operator) {
      243             super(depth, coll);
      244             this.operator = operator;
      245             this.base = base;
      246         }
      247 
      248         @Override
      249         public void seq() {
      250             value = coll.sequential().reduce(base, operator);
      251         }
      252 
      253         @Override
      254         public void combine(ReduceTask<T> left, ReduceTask<T> right) {
      255             value = operator.eval(left.value, right.value);
      256         }
      257 
      258         @Override
      259         public ReduceTask<T> makeTask(int depth, ParallelIterable<T> coll) {
      260             return new ReduceTask<>(depth, coll, base, operator);
      261         }
      262     }
      263 
      264     private static class MapReduceTask<T, U> extends BaseTask<T, MapReduceTask<T, U>> {
      265         private static final long serialVersionUID = 1L;
      266         public final Operator<U> operator;
      267         public final Mapper<? super T,U> mapper;
      268         public final U base;
      269         public U value;
      270 
      271         MapReduceTask(int depth, ParallelIterable<T> coll, Mapper<? super T, U> mapper, U base, Operator<U> operator) {
      272             super(depth, coll);
      273             this.operator = operator;
      274             this.base = base;
      275             this.mapper = mapper;
      276         }
      277 
      278         @Override
      279         public void seq() {
      280             value = coll.sequential().mapReduce(mapper, base, operator);
      281         }
      282 
      283         @Override
      284         public void combine(MapReduceTask<T,U> left, MapReduceTask<T,U> right) {
      285             value = operator.eval(left.value, right.value);
      286         }
      287 
      288         @Override
      289         public MapReduceTask<T,U> makeTask(int depth, ParallelIterable<T> coll) {
      290             return new MapReduceTask<>(depth, coll, mapper, base, operator);
      291         }
      292     }
      293 
      294     private static class IntMapReduceTask<T> extends BaseTask<T, IntMapReduceTask<T>> {
      295         private static final long serialVersionUID = 1L;
      296         public final IntOperator operator;
      297         public final IntMapper<? super T> mapper;
      298         public final int base;
      299         public int value;
      300 
      301         IntMapReduceTask(int depth, ParallelIterable<T> coll, IntMapper<? super T> mapper, int base, IntOperator operator) {
      302             super(depth, coll);
      303             this.operator = operator;
      304             this.base = base;
      305             this.mapper = mapper;
      306         }
      307 
      308         @Override
      309         public void seq() {
      310             value = coll.sequential().mapReduce(mapper, base, operator);
      311         }
      312 
      313         @Override
      314         public void combine(IntMapReduceTask<T> left, IntMapReduceTask<T> right) {
      315             value = operator.eval(left.value, right.value);
      316         }
      317 
      318         @Override
      319         public IntMapReduceTask<T> makeTask(int depth, ParallelIterable<T> coll) {
      320             return new IntMapReduceTask<>(depth, coll, mapper, base, operator);
      321         }
      322     }
      323 
      324     private static class LongMapReduceTask<T> extends BaseTask<T, LongMapReduceTask<T>> {
      325         private static final long serialVersionUID = 1L;
      326         public final LongOperator operator;
      327         public final LongMapper<? super T> mapper;
      328         public final long base;
      329         public long value;
      330 
      331         LongMapReduceTask(int depth, ParallelIterable<T> coll, LongMapper<? super T> mapper, long base, LongOperator operator) {
      332             super(depth, coll);
      333             this.operator = operator;
      334             this.base = base;
      335             this.mapper = mapper;
      336         }
      337 
      338         @Override
      339         public void seq() {
      340             value = coll.sequential().mapReduce(mapper, base, operator);
      341         }
      342 
      343         @Override
      344         public void combine(LongMapReduceTask<T> left, LongMapReduceTask<T> right) {
      345             value = operator.eval(left.value, right.value);
      346         }
      347 
      348         @Override
      349         public LongMapReduceTask<T> makeTask(int depth, ParallelIterable<T> coll) {
      350             return new LongMapReduceTask<>(depth, coll, mapper, base, operator);
      351         }
      352     }
      353 
      354     private static class DoubleMapReduceTask<T> extends BaseTask<T, DoubleMapReduceTask<T>> {
      355         private static final long serialVersionUID = 1L;
      356         public final DoubleOperator operator;
      357         public final DoubleMapper<? super T> mapper;
      358         public final double base;
      359         public double value;
      360 
      361         DoubleMapReduceTask(int depth, ParallelIterable<T> coll, DoubleMapper<? super T> mapper, double base, DoubleOperator operator) {
      362             super(depth, coll);
      363             this.operator = operator;
      364             this.base = base;
      365             this.mapper = mapper;
      366         }
      367 
      368         @Override
      369         public void seq() {
      370             value = coll.sequential().mapReduce(mapper, base, operator);
      371         }
      372 
      373         @Override
      374         public void combine(DoubleMapReduceTask<T> left, DoubleMapReduceTask<T> right) {
      375             value = operator.eval(left.value, right.value);
      376         }
      377 
      378         @Override
      379         public DoubleMapReduceTask<T> makeTask(int depth, ParallelIterable<T> coll) {
      380             return new DoubleMapReduceTask<>(depth, coll, mapper, base, operator);
      381         }
      382     }
      383 
      384     /**
      385      * Reduce elements to a single value.
      386      *
      387      * @param <T> Type of elements
      388      * @param pi The source of elements.
      389      * @param operator Reduces elements to a result of type {@code U}.
      390      * @param base Initial value for reducer.
      391      * @return The reduced value of the elements.
      392      */
      393     public static <T> T reduce(ParallelIterable<T> pi, T base, Operator<T> operator) {
      394         Objects.requireNonNull(pi);
      395         Objects.requireNonNull(operator);
      396         ReduceTask<T> task = new ReduceTask<>(calculateDepth(pi.estimateCount()), pi, base, operator);
      397         ForkJoinUtils.defaultFJPool().invoke(task);
      398         return task.value;
      399     }
      400 
      401     public static <T, U> U mapReduce(ParallelIterable<T> pi, Mapper<? super T, U> mapper, U base, Operator<U> operator) {
      402         Objects.requireNonNull(pi);
      403         Objects.requireNonNull(mapper);
      404         Objects.requireNonNull(operator);
      405         MapReduceTask<T, U> task = new MapReduceTask<>(calculateDepth(pi.estimateCount()), pi, mapper, base, operator);
      406         ForkJoinUtils.defaultFJPool().invoke(task);
      407         return task.value;
      408     }
      409 
      410     public static<T> int mapReduce(ParallelIterable<T> pi, IntMapper<? super T> mapper, int base, IntOperator operator) {
      411         Objects.requireNonNull(pi);
      412         Objects.requireNonNull(mapper);
      413         Objects.requireNonNull(operator);
      414         IntMapReduceTask<T> task = new IntMapReduceTask<>(calculateDepth(pi.estimateCount()), pi, mapper, base, operator);
      415         ForkJoinUtils.defaultFJPool().invoke(task);
      416         return task.value;
      417     }
      418 
      419     public static<T> long mapReduce(ParallelIterable<T> pi, LongMapper<? super T> mapper, long base, LongOperator operator) {
      420         Objects.requireNonNull(pi);
      421         Objects.requireNonNull(mapper);
      422         Objects.requireNonNull(operator);
      423         LongMapReduceTask<T> task = new LongMapReduceTask<>(calculateDepth(pi.estimateCount()), pi, mapper, base, operator);
      424         ForkJoinUtils.defaultFJPool().invoke(task);
      425         return task.value;
      426     }
      427 
      428     public static<T> double mapReduce(ParallelIterable<T> pi, DoubleMapper<? super T> mapper, double base, DoubleOperator operator) {
      429         Objects.requireNonNull(pi);
      430         Objects.requireNonNull(mapper);
      431         Objects.requireNonNull(operator);
      432         DoubleMapReduceTask<T> task = new DoubleMapReduceTask<>(calculateDepth(pi.estimateCount()), pi, mapper, base, operator);
      433         ForkJoinUtils.defaultFJPool().invoke(task);
      434         return task.value;
      435     }
      436 
      437     /**
      438      * All elements of the Iterable are added to the specified container.
      439      *
      440      * @param <T> Type of elements
      441      * @param <A>
      442      * @param pi The source of elements.
      443      * @param target The collection other container into which the elements are added.
      444      * @return The provided container.
      445      */
      446     public static <T, A extends Fillable<? super T>> A into(ParallelIterable<T> pi, A target) {
      447         // Current implementation is sequential, pending more analysis on encounter-order-preservation
      448         Objects.requireNonNull(pi);
      449         Objects.requireNonNull(target);
      450         target.addAll(pi.sequential());
      451         return target;
      452     }
      453 
      454     // TODO: better short-circuiting
      455     private static class MatchTask<T> extends BaseTask<T, MatchTask<T>> {
      456         private static final long serialVersionUID = 1L;
      457         enum Kind { ANY, ALL, NONE };
      458         public final Predicate<? super T> predicate;
      459         public final Kind kind;
      460         public boolean value;
      461 
      462         MatchTask(int depth, ParallelIterable<T> coll, Predicate<? super T> predicate, Kind kind) {
      463             super(depth, coll);
      464             this.predicate = predicate;
      465             this.kind = kind;
      466         }
      467 
      468         @Override
      469         public void seq() {
      470             switch (kind) {
      471                 case ANY: value = coll.sequential().anyMatch(predicate); break;
      472                 case ALL: value = coll.sequential().allMatch(predicate); break;
      473                 case NONE: value = coll.sequential().noneMatch(predicate); break;
      474             }
      475         }
      476 
      477         @Override
      478         public void combine(MatchTask<T> left, MatchTask<T> right) {
      479             switch (kind) {
      480                 case ANY: value = left.value || right.value; break;
      481                 case ALL: value = left.value && right.value; break;
      482                 case NONE: value = left.value && right.value; break;
      483             }
      484         }
      485 
      486         @Override
      487         public MatchTask<T> makeTask(int depth, ParallelIterable<T> coll) {
      488             return new MatchTask<>(depth, coll, predicate, kind);
      489         }
      490     }
      491 
      492     public static <T> boolean anyMatch(ParallelIterable<T> pi, Predicate<? super T> predicate) {
      493         Objects.requireNonNull(pi);
      494         Objects.requireNonNull(predicate);
      495         MatchTask<T> task = new MatchTask<>(calculateDepth(pi.estimateCount()), pi, predicate, MatchTask.Kind.ANY);
      496         ForkJoinUtils.defaultFJPool().invoke(task);
      497         return task.value;
      498     }
      499 
      500     public static <T> boolean noneMatch(ParallelIterable<T> pi, Predicate<? super T> predicate) {
      501         Objects.requireNonNull(pi);
      502         Objects.requireNonNull(predicate);
      503         MatchTask<T> task = new MatchTask<>(calculateDepth(pi.estimateCount()), pi, predicate, MatchTask.Kind.NONE);
      504         ForkJoinUtils.defaultFJPool().invoke(task);
      505         return task.value;
      506     }
      507 
      508     public static <T> boolean allMatch(ParallelIterable<T> pi, Predicate<? super T> predicate) {
      509         Objects.requireNonNull(pi);
      510         Objects.requireNonNull(predicate);
      511         MatchTask<T> task = new MatchTask<>(calculateDepth(pi.estimateCount()), pi, predicate, MatchTask.Kind.ALL);
      512         ForkJoinUtils.defaultFJPool().invoke(task);
      513         return task.value;
      514     }
      515 
      516     public static<T extends Comparable<? super T>> ParallelIterable<T> sorted(ParallelIterable<T> pi) {
      517         Objects.requireNonNull(pi);
      518 
      519         throw new UnsupportedOperationException("nyi");
      520     }
      521 
      522     public static<T> ParallelIterable<T> sorted(ParallelIterable<T> pi, final Comparator<? super T> comparator) {
      523         Objects.requireNonNull(pi);
      524         Objects.requireNonNull(comparator);
      525         throw new UnsupportedOperationException("nyi");
      526     }
      527 
      528     public static<T> ParallelIterable<T> uniqueElements(ParallelIterable<T> pi) {
      529         Objects.requireNonNull(pi);
      530         throw new UnsupportedOperationException("nyi");
      531     }
      532 }