
Parallel flatMap always sequential

Suppose I have this code:

 Collections.singletonList(10)
            .parallelStream() // .stream() - nothing changes
            .flatMap(x -> Stream.iterate(0, i -> i + 1)
                    .limit(x)
                    .parallel()
                    .peek(m -> {
                        System.out.println(Thread.currentThread().getName());
                    }))
            .collect(Collectors.toSet());

The output shows the same thread name on every line, so a single thread does all the work and parallel() brings no benefit here.
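To make the observation easier to check, here is a self-contained variant of the snippet that collects the thread names instead of printing them. The class name FlatMapThreadDemo and the method threadsUsed are made up for this illustration; the pipeline itself is the one from the question.

```java
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FlatMapThreadDemo {

    // Runs the pipeline from the question and returns the names of all
    // threads that processed an element of the inner stream.
    static Set<String> threadsUsed() {
        Set<String> names = ConcurrentHashMap.newKeySet();
        Collections.singletonList(10)
                .parallelStream()
                .flatMap(x -> Stream.iterate(0, i -> i + 1)
                        .limit(x)
                        .parallel() // no effect: flatMap consumes the inner stream sequentially
                        .peek(m -> names.add(Thread.currentThread().getName())))
                .collect(Collectors.toSet());
        return names;
    }

    public static void main(String[] args) {
        // Prints a set containing a single thread name.
        System.out.println(threadsUsed());
    }
}
```

The returned set should contain exactly one name, whether the outer stream is created with parallelStream() or stream().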

Inside flatMap there is this code:

result.sequential().forEach(downstream);

I understand forcing the inner stream to be sequential when the "outer" stream is parallel: the two could probably block each other, since "outer" would have to wait for "flatMap" to finish and the other way around (the same common pool is used). But why always force it?

Is that one of those things that could change in a later version?

Eugene asked Jul 11 '17 14:07



2 Answers

There are two different aspects.

First, there is only a single pipeline which is either sequential or parallel. The choice of sequential or parallel at the inner stream is irrelevant. Note that the downstream consumer you see in the cited code snippet represents the entire subsequent stream pipeline, so in your code, ending with .collect(Collectors.toSet());, this consumer will eventually add the resulting elements to a single Set instance which is not thread safe. So processing the inner stream in parallel with that single consumer would break the entire operation.

If an outer stream gets split, that cited code might get invoked concurrently with different consumers adding to different sets. Each of these calls would process a different element of the outer stream mapping to a different inner stream instance. Since your outer stream consists of a single element only, it can’t be split.
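The following sketch illustrates that point with an outer stream of several elements. It records, per outer element, which threads processed the elements of its inner stream. The class and method names are invented for this example, and the sizes (8 outer elements, 1,000 inner elements each) are arbitrary.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class OuterSplitDemo {

    // For each outer element, records which threads processed its inner elements.
    static Map<Integer, Set<String>> threadsPerOuterElement() {
        Map<Integer, Set<String>> threads = new ConcurrentHashMap<>();
        List<Integer> outer = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8);
        outer.parallelStream()
                .flatMap(x -> IntStream.range(0, 1_000).boxed()
                        .parallel() // still consumed sequentially by flatMap
                        .peek(i -> threads
                                .computeIfAbsent(x, k -> ConcurrentHashMap.newKeySet())
                                .add(Thread.currentThread().getName())))
                .collect(Collectors.toList());
        return threads;
    }

    public static void main(String[] args) {
        threadsPerOuterElement()
                .forEach((x, names) -> System.out.println(x + " -> " + names));
    }
}
```

Each outer element should map to exactly one thread name, because its inner stream is drained by whichever thread handles that outer element. Different outer elements may end up on different threads, depending on how the outer stream is split.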

The way this has been implemented is also the reason for the issue described in "Why filter() after flatMap() is “not completely” lazy in Java streams?", as forEach is called on the inner stream, which passes all of its elements to the downstream consumer. As demonstrated by this answer, an alternative implementation that supports laziness and substream splitting is possible, but it is a fundamentally different way of implementing it. The current design of the Stream implementation mostly works by consumer composition, so in the end the source spliterator (and those split off from it) receives a Consumer representing the entire stream pipeline in either tryAdvance or forEachRemaining. In contrast, the solution of the linked answer does spliterator composition, producing a new Spliterator that delegates to the source spliterators. I suppose both approaches have advantages, and I’m not sure how much the OpenJDK implementation would lose when working the other way round.
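To give a rough idea of what spliterator composition looks like, here is a much-simplified sketch. It is not the code of the linked answer and not how OpenJDK implements flatMap: it is sequential only, does no substream splitting, and does not close the inner streams. The names LazyFlatMapSketch and lazyFlatMap are hypothetical.

```java
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class LazyFlatMapSketch {

    // Builds a new Spliterator that delegates to the outer spliterator and to
    // the spliterator of the current inner stream, pulling one element at a time.
    static <T, R> Stream<R> lazyFlatMap(
            Stream<T> outer,
            Function<? super T, ? extends Stream<? extends R>> mapper) {
        Spliterator<T> outerSp = outer.spliterator();
        Spliterator<R> composed =
                new Spliterators.AbstractSpliterator<R>(Long.MAX_VALUE, 0) {
            // Spliterator of the inner stream currently being drained, or null.
            private Spliterator<? extends R> inner;

            @Override
            public boolean tryAdvance(Consumer<? super R> action) {
                while (true) {
                    // Hand out the next inner element if there is one.
                    if (inner != null && inner.tryAdvance(action)) {
                        return true;
                    }
                    inner = null;
                    // Current inner stream is exhausted: map the next outer element.
                    boolean advanced = outerSp.tryAdvance(t -> {
                        Stream<? extends R> s = mapper.apply(t);
                        inner = (s == null) ? null : s.spliterator();
                    });
                    if (!advanced) {
                        return false; // outer stream is exhausted as well
                    }
                }
            }
        };
        return StreamSupport.stream(composed, false).onClose(outer::close);
    }

    public static void main(String[] args) {
        // The inner stream is infinite, yet findFirst() returns, because only
        // as many inner elements are pulled as the terminal operation needs.
        Stream<Integer> lazy =
                lazyFlatMap(Stream.of(10), x -> Stream.iterate(0, i -> i + 1));
        System.out.println(lazy.findFirst().get());

        Stream<Integer> flat =
                lazyFlatMap(Stream.of(1, 2, 3), x -> Stream.iterate(0, i -> i + 1).limit(x));
        System.out.println(flat.collect(Collectors.toList()));
    }
}
```

Because the composed spliterator hands out one element per tryAdvance call, a short-circuiting terminal operation stops pulling as soon as it is satisfied. A version that also supports splitting would additionally have to implement trySplit, which is where most of the complexity lies.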

Holger answered Oct 20 '22 17:10


This is for anyone who, like me, urgently needs to parallelize flatMap and wants a practical solution, not only history and theory.

The simplest solution I came up with is to do flattening by hand, basically by replacing it with map + reduce(Stream::concat).

Here's an example to demonstrate how to do this:

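// Imports the snippets below rely on (JUnit 5 is an assumption; adjust to your test framework)
import static java.lang.Thread.currentThread;

import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;

@Test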
void testParallelStream_NOT_WORKING() throws InterruptedException, ExecutionException {
    new ForkJoinPool(10).submit(() -> {
        Stream.iterate(0, i -> i + 1).limit(2)
                .parallel()

                // does not parallelize nested streams
                .flatMap(i -> generateRangeParallel(i, 100))

                .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
    }).get();
    System.out.println("done");
}

@Test
void testParallelStream_WORKING() throws InterruptedException, ExecutionException {
    new ForkJoinPool(10).submit(() -> {
        Stream.iterate(0, i -> i + 1).limit(2)
                .parallel()

                // concatenation of nested streams instead of flatMap, parallelizes ALL the items
                .map(i -> generateRangeParallel(i, 100))
                .reduce(Stream::concat).orElse(Stream.empty())

                .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
    }).get();
    System.out.println("done");
}

Stream<Integer> generateRangeParallel(int start, int num) {
    return Stream.iterate(start, i -> i + 1).limit(num).parallel();
}

// run this method with produced output to see how work was distributed
void countThreads(String strOut) {
    var res = Arrays.stream(strOut.split("\n"))
            .map(line -> line.split("\\s+"))
            .collect(Collectors.groupingBy(s -> s[0], Collectors.counting()));
    System.out.println(res);
    System.out.println("threads  : " + res.keySet().size());
    System.out.println("work     : " + res.values());
}

Stats from run on my machine:

NOT_WORKING case stats:
{ForkJoinPool-1-worker-23=100, ForkJoinPool-1-worker-5=300}
threads  : 2
work     : [100, 300]

WORKING case stats:
{ForkJoinPool-1-worker-9=16, ForkJoinPool-1-worker-23=20, ForkJoinPool-1-worker-21=36, ForkJoinPool-1-worker-31=17, ForkJoinPool-1-worker-27=177, ForkJoinPool-1-worker-13=17, ForkJoinPool-1-worker-5=21, ForkJoinPool-1-worker-19=8, ForkJoinPool-1-worker-17=21, ForkJoinPool-1-worker-3=67}
threads  : 10
work     : [16, 20, 36, 17, 177, 17, 21, 8, 21, 67]
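For readers who want to try the workaround without a test framework, here is a minimal standalone sketch of the same map + reduce(Stream::concat) idea. The class name ConcatFlattenDemo and the method flattenByConcat are made up for this example; generateRangeParallel is the helper from above. It only checks the flattened result, not the thread distribution, which varies between runs and machines.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ConcatFlattenDemo {

    // Same helper as above: a parallel stream of num consecutive integers.
    static Stream<Integer> generateRangeParallel(int start, int num) {
        return Stream.iterate(start, i -> i + 1).limit(num).parallel();
    }

    // Flattens by hand: map each outer element to its inner stream, then
    // concatenate all inner streams into one stream and collect it.
    static List<Integer> flattenByConcat(int outerCount, int innerCount) {
        return Stream.iterate(0, i -> i + 1).limit(outerCount)
                .parallel()
                .map(i -> generateRangeParallel(i, innerCount))
                .reduce(Stream::concat).orElse(Stream.empty())
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(flattenByConcat(2, 5));
    }
}
```

The encounter order is preserved, so the elements come out in the same order flatMap would produce. One caveat: the Javadoc of Stream.concat warns that repeated concatenation can lead to deep call chains or even a StackOverflowError, so this approach is probably best kept to a moderate number of inner streams.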

Dmytro Buryak answered Oct 20 '22 17:10