Why does this code using streams run so much faster in Java 9 than Java 8?

I discovered this while solving Problem 205 of Project Euler. The problem is as follows:

Peter has nine four-sided (pyramidal) dice, each with faces numbered 1, 2, 3, 4. Colin has six six-sided (cubic) dice, each with faces numbered 1, 2, 3, 4, 5, 6.

Peter and Colin roll their dice and compare totals: the highest total wins. The result is a draw if the totals are equal.

What is the probability that Pyramidal Pete beats Cubic Colin? Give your answer rounded to seven decimal places in the form 0.abcdefg

I wrote a naive solution using Guava:

import com.google.common.collect.Sets;
import com.google.common.collect.ImmutableSet;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.Collectors;

public class Problem205 {
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        List<Integer> peter = Sets.cartesianProduct(Collections.nCopies(9, ImmutableSet.of(1, 2, 3, 4)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());
        List<Integer> colin = Sets.cartesianProduct(Collections.nCopies(6, ImmutableSet.of(1, 2, 3, 4, 5, 6)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());

        long startTime2 = System.currentTimeMillis();
        // IMPORTANT BIT HERE! v
        long solutions = peter
                .stream()
                .mapToLong(p -> colin
                        .stream()
                        .filter(c -> p > c)
                        .count())
                .sum();

        // IMPORTANT BIT HERE! ^
        System.out.println("Counting solutions took " + (System.currentTimeMillis() - startTime2) + "ms");

        System.out.println("Solution: " + BigDecimal
                .valueOf(solutions)
                .divide(BigDecimal
                                .valueOf((long) Math.pow(4, 9) * (long) Math.pow(6, 6)),
                        7,
                        RoundingMode.HALF_UP));
        System.out.println("Found in: " + (System.currentTimeMillis() - startTime) + "ms");
    }
}

The code I have highlighted, which uses a simple filter(), count() and sum(), seems to run much faster in Java 9 than Java 8. Specifically, Java 8 counts the solutions in 37465ms on my machine. Java 9 does it in about 16000ms, which is the same whether I run the file compiled with Java 8 or one compiled with Java 9.

If I replace the streams code with what would seem to be the exact pre-streams equivalent:

long solutions = 0;
for (Integer p : peter) {
    long count = 0;
    for (Integer c : colin) {
        if (p > c) {
            count++;
        }
    }
    solutions += count;
}

It counts the solutions in about 35000ms, with no measurable difference between Java 8 and Java 9.

What am I missing here? Why is the streams code so much faster in Java 9, and why isn't the for loop?


I am running Ubuntu 16.04 LTS 64-bit. My Java 8 version:

java version "1.8.0_131"
Java(TM) SE Runtime Environment (build 1.8.0_131-b11)
Java HotSpot(TM) 64-Bit Server VM (build 25.131-b11, mixed mode)

My Java 9 version:

java version "9"
Java(TM) SE Runtime Environment (build 9+181)
Java HotSpot(TM) 64-Bit Server VM (build 9+181, mixed mode)

Asked by bcsb1001 on Oct 07 '17 at 00:10


1 Answer

1. Why the stream runs faster on JDK 9

The Stream.count() implementation in JDK 8 is rather dumb: it just iterates through the whole stream, adding 1L for each element.

This was fixed in JDK 9. Although the bug report talks about SIZED streams, the new code improves non-sized streams, too.

If you replace .count() with the Java 8-style implementation, .mapToLong(e -> 1L).sum(), it will be slow again even on JDK 9.
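
As a minimal sketch of the two counting styles side by side (the class and method names CountStyles, viaCount and viaSum are hypothetical, chosen only for illustration): both variants return the same number, so swapping one for the other in the question's code should change only the timing, not the result.

```java
import java.util.Arrays;
import java.util.List;

public class CountStyles {
    // Uses Stream.count(), the call whose implementation changed in JDK 9.
    static long viaCount(List<Integer> values, int threshold) {
        return values.stream()
                .filter(v -> threshold > v)
                .count();
    }

    // Hand-written equivalent of the JDK 8 behaviour: map every element to 1L and sum.
    static long viaSum(List<Integer> values, int threshold) {
        return values.stream()
                .filter(v -> threshold > v)
                .mapToLong(e -> 1L)
                .sum();
    }

    public static void main(String[] args) {
        List<Integer> values = Arrays.asList(1, 2, 3, 4, 5, 6);
        System.out.println(viaCount(values, 4)); // prints 3
        System.out.println(viaSum(values, 4));   // prints 3
    }
}
```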

2. Why the naive loop is slow

When you put all your code in the main method, it cannot be JIT-compiled efficiently. The method is executed only once, so it starts running in the interpreter. Later, when the JVM detects a hot loop, it switches from interpreted to compiled code while the method is still running. This is called on-stack replacement (OSR).

OSR compilations are often not as well optimized as regularly compiled methods. I've explained this in detail earlier; see this and this answer.
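
One way to see OSR happen is to run a small program with HotSpot's -XX:+PrintCompilation flag. The sketch below is a hypothetical example (OsrDemo and sumBelow are made-up names): the method is called only once, so the loop inside it becomes hot while the method is still executing. In the flag's output, OSR compilations are marked with a % sign; the exact output layout varies between JVM versions.

```java
public class OsrDemo {
    // Sums 0..n-1. Because this method is invoked a single time, the JVM cannot
    // simply wait for the next call to switch to compiled code; it has to
    // replace the running interpreted frame instead (on-stack replacement).
    static long sumBelow(int n) {
        long sum = 0;
        for (int i = 0; i < n; i++) {
            sum += i;
        }
        return sum;
    }

    public static void main(String[] args) {
        // Run with: java -XX:+PrintCompilation OsrDemo
        // and look for a line mentioning OsrDemo::sumBelow marked with '%'.
        System.out.println(sumBelow(1_000_000_000));
    }
}
```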

The JIT will produce better code if you move the inner loop into a separate method:

    long solutions = 0;
    for (Integer p : peter) {
        solutions += countLargerThan(colin, p);
    }

    ...

    private static int countLargerThan(List<Integer> colin, int p) {
        int count = 0;
        for (Integer c : colin) {
            if (p > c) {
                count++;
            }
        }
        return count;
    }

In this case the countLargerThan method will be compiled normally, and the performance will be better than with streams on both JDK 8 and JDK 9.
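
For anyone who wants to try this without Guava, here is a self-contained sketch of the extracted-method version. The names Problem205Loop, allTotals and countWins are assumptions introduced for this example, and allTotals is a plain-loop stand-in for the Sets.cartesianProduct code in the question, not the asker's original approach. Timings will depend on the machine and the JVM.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Problem205Loop {
    // Every possible total for `dice` dice with faces 1..faces, one entry per outcome.
    static List<Integer> allTotals(int dice, int faces) {
        List<Integer> totals = Collections.singletonList(0);
        for (int d = 0; d < dice; d++) {
            List<Integer> next = new ArrayList<>(totals.size() * faces);
            for (int t : totals) {
                for (int f = 1; f <= faces; f++) {
                    next.add(t + f);
                }
            }
            totals = next;
        }
        return totals;
    }

    // The inner loop lives in its own method, so it is called many times
    // and can be compiled as a regular method rather than through OSR.
    static int countLargerThan(List<Integer> colin, int p) {
        int count = 0;
        for (Integer c : colin) {
            if (p > c) {
                count++;
            }
        }
        return count;
    }

    // Number of (peter, colin) outcome pairs in which Peter's total is higher.
    static long countWins(List<Integer> peter, List<Integer> colin) {
        long solutions = 0;
        for (Integer p : peter) {
            solutions += countLargerThan(colin, p);
        }
        return solutions;
    }

    public static void main(String[] args) {
        List<Integer> peter = allTotals(9, 4);
        List<Integer> colin = allTotals(6, 6);
        long start = System.currentTimeMillis();
        long solutions = countWins(peter, colin);
        System.out.println("Counting solutions took " + (System.currentTimeMillis() - start) + "ms");
        System.out.println("Winning pairs: " + solutions);
    }
}
```

With one four-sided die against one six-sided die, countWins returns 6, which is easy to check by hand: a roll of 2 beats one outcome, 3 beats two, and 4 beats three.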

Answered by apangin on Oct 26 '22 at 23:10