
How to lazily evaluate nested flatMap

I'm trying to build a Cartesian product from two potentially infinite streams, which I then truncate with limit().

This has been (approximately) my strategy so far:

@Test
void flatMapIsLazy() {
    Stream.of("a", "b", "c")
        .flatMap(s -> Stream.of("x", "y")
            .flatMap(sd -> IntStream.rangeClosed(0, Integer.MAX_VALUE)
                .mapToObj(sd::repeat)))
        .map(s -> s + "u")
        .limit(20)
        .forEach(System.out::println);
}

This doesn't work.

Apparently, the nested stream is evaluated to completion the first time the pipeline pulls from it. It doesn't behave like a lazy stream that I can consume at my own pace.

I think the .forEach in this piece of code from ReferencePipeline#flatMap is to blame:

@Override
public void accept(P_OUT u) {
    try (Stream<? extends R> result = mapper.apply(u)) {
        if (result != null) {
            if (!cancellationRequestedCalled) {
                result.sequential().forEach(downstream);
            }
            else {
                var s = result.sequential().spliterator();
                do { } while (!downstream.cancellationRequested() && s.tryAdvance(downstream));
            }
        }
    }
}

I expected the above code to print 20 elements looking like:

a
ax
axx
axxx
axxxx
...
axxxxxxxxxxxxxxxxxxx

Instead it crashes with an OutOfMemoryError, because the very long stream inside the nested flatMap appears to be evaluated eagerly and fills memory with unneeded copies of the repeated strings. If I used 3 instead of Integer.MAX_VALUE and kept the limit at 20, the expected output would be:

a
ax
axx
axxx
a
ay
ayy
ayyy
b
bx
bxx
bxxx
...
(up until 20 lines)

Edit: At this point I have just rolled my own implementation with lazy iterators. Still, I think there should be a way to do this with pure Streams.
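
For readers wondering what an iterator-based workaround could look like, here is a minimal sketch. It is not the asker's actual implementation: the class LazyFlatMap and the helpers lazyFlatMap and product are hypothetical names introduced for illustration, and the code assumes Java 11+ (for String.repeat). It flattens by pulling one element at a time from the current inner stream's iterator, and it builds the strings as s + sd.repeat(n) so the output matches the expected listing above.

```java
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class LazyFlatMap {

    // Hypothetical helper: a flatMap replacement that advances the inner
    // stream one element at a time instead of draining it with forEach.
    // Note: unlike Stream.flatMap, this sketch does not close inner streams.
    static <T, R> Stream<R> lazyFlatMap(
            Stream<T> outer,
            Function<? super T, ? extends Stream<? extends R>> mapper) {
        Iterator<T> outerIt = outer.iterator();
        Iterator<R> flat = new Iterator<>() {
            private Iterator<? extends R> inner = Collections.emptyIterator();

            @Override
            public boolean hasNext() {
                // Move to the next non-empty inner stream only when needed.
                while (!inner.hasNext()) {
                    if (!outerIt.hasNext()) {
                        return false;
                    }
                    inner = mapper.apply(outerIt.next()).iterator();
                }
                return true;
            }

            @Override
            public R next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                return inner.next();
            }
        };
        return StreamSupport.stream(
            Spliterators.spliteratorUnknownSize(flat, Spliterator.ORDERED), false);
    }

    // Hypothetical driver: the nested product from the question, with the
    // upper bound of the repeat count and the overall limit as parameters.
    static List<String> product(int maxRepeat, int limit) {
        Function<String, Stream<String>> inner = s ->
            lazyFlatMap(Stream.of("x", "y"),
                sd -> IntStream.rangeClosed(0, maxRepeat)
                    .mapToObj(n -> s + sd.repeat(n)));
        return lazyFlatMap(Stream.of("a", "b", "c"), inner)
            .limit(limit)
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        product(Integer.MAX_VALUE, 20).forEach(System.out::println);
    }
}
```

Because every level is driven by Iterator.hasNext()/next(), limit(20) only ever requests 20 elements, so the Integer.MAX_VALUE range is never materialized.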

Edit 2: This has been accepted as a bug report for Java: https://bugs.java.com/bugdatabase/view_bug.do?bug_id=JDK-8267758

Nirro asked May 24 '21 04:05


1 Answer

As you have already written, this has been accepted as a bug. It may be fixed in a future version of Java.

But there is a workaround you can use now. It is not very elegant, and it is viable only if the number of elements in the outer stream and the limit are both small enough. Under these restrictions, it works.

Let me first modify your example slightly by splitting the outer flatMap into two operations (a map, followed by a flatMap with the identity function that only flattens):

Stream.of("a", "b", "c")
      .map(s -> Stream.of("x", "y")
            .flatMap(sd -> IntStream.rangeClosed(0, Integer.MAX_VALUE)
                  .mapToObj(sd::repeat)))
      .flatMap(s -> s)
      .map(s -> s + "u")
      .limit(20)
      .forEach(System.out::println);

We can easily see that we need no more than 20 elements from each inner stream, so we can limit each inner stream to that number of elements. This will work (you should use a variable or constant for the limit):

Stream.of("a", "b", "c")
      .map(s -> Stream.of("x", "y")
            .flatMap(sd -> IntStream.rangeClosed(0, Integer.MAX_VALUE)
                  .mapToObj(sd::repeat)))
      .flatMap(s -> s.limit(20))            // limit each inner stream
      .map(s -> s + "u")
      .limit(20)
      .forEach(System.out::println);

Of course, this will still produce too many intermediate results, but that may not be a big problem under the above restrictions.
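
As a sketch of the "variable or constant" suggestion, the pipeline can be wrapped in a method that takes the limit once and applies it in both places. The class name BoundedInnerStreams and the method firstElements are hypothetical, introduced only for this example, and it assumes Java 11+ (for String.repeat). It also collects into a list instead of printing so the result is easy to inspect.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class BoundedInnerStreams {

    // Hypothetical helper: the same pipeline as above, with the limit
    // supplied as a parameter rather than repeated as a literal.
    static List<String> firstElements(int limit) {
        return Stream.of("a", "b", "c")
            .map(s -> Stream.of("x", "y")
                .flatMap(sd -> IntStream.rangeClosed(0, Integer.MAX_VALUE)
                    .mapToObj(sd::repeat)))
            .flatMap(inner -> inner.limit(limit))   // cap each inner stream
            .map(s -> s + "u")
            .limit(limit)                           // cap the overall result
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        firstElements(20).forEach(System.out::println);
    }
}
```

Since the outer value s is not used when building the strings, the first inner stream alone already supplies all the requested elements ("u", "xu", "xxu", and so on).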

Donat answered Oct 21 '22 23:10