
Java 8 streams: can you capture/reuse a value calculated in a filter?

I'm trying to convert an "old way" loop into a streams-based approach. The loop takes one large set of elements and returns the subset that falls within a given radius. The results are sorted by distance, and each result carries its calculated distance (for presentation). It works fine the old way, and I don't need to Java8-ify it. But I really want to. :-) If only to be able to go .parallel() on this sucker.

The catch is...my filter() uses a calculated value (the distance), which I then need in a subsequent map() step (to construct the "with distance" instance). Assume the distance calculation is expensive. Here's the Java 7 way...scroll down to see the getNearbyStations() method:

public interface Coordinate {
    double distanceTo(Coordinate other);
}

public class Station {
    private final String name;
    private final Coordinate coordinate;

    public Station(String name, Coordinate coordinate) {
        this.name = name;
        this.coordinate = coordinate;
    }

    public String getName() {
        return name;
    }

    public Coordinate getCoordinate() {
        return coordinate;
    }
}

public class StationWithDistance extends Station implements Comparable<StationWithDistance> {
    private final double distance;

    public StationWithDistance(Station station, double distance) {
        super(station.getName(), station.getCoordinate());
        this.distance = distance;
    }

    public double getDistance() {
        return distance;
    }

    public int compareTo(StationWithDistance s2) {
        return Double.compare(this.distance, s2.distance);
    }
}

// Assume this contains many entries
private final List<Station> allStations = new ArrayList<>();

public List<StationWithDistance> getNearbyStations(Coordinate origin, double radius) {
    List<StationWithDistance> nearbyStations = new ArrayList<>();
    for (Station station : allStations) {
        double distance = origin.distanceTo(station.getCoordinate());
        if (distance <= radius) {
            nearbyStations.add(new StationWithDistance(station, distance));
        }
    }
    Collections.sort(nearbyStations);
    return nearbyStations;
}

Now...here's a dumb/brute force streams-based approach. Note that it performs the distance calculation twice (stoopid), but it's a step closer to parallel()ized:

public List<StationWithDistance> getNearbyStationsNewWay(Coordinate origin, double radius) {
    return allStations.stream()
        .parallel()
        .filter(s -> origin.distanceTo(s.getCoordinate()) <= radius)
        .map(s -> new StationWithDistance(s, origin.distanceTo(s.getCoordinate())))
        .sorted()
        .collect(Collectors.toList());
}

Trying to figure out a Better Way(tm), this is all I've come up with so far in order to avoid the duplicate calculation:

public List<StationWithDistance> getNearbyStationsNewWay(Coordinate origin, double radius) {
    return allStations.stream()
        .parallel()
        .map(s -> new StationWithDistance(s, origin.distanceTo(s.getCoordinate())))
        .filter(s -> s.getDistance() <= radius)
        .sorted()
        .collect(Collectors.toList());
}

...but that produces garbage -- most of the StationWithDistance instances created get filtered out.

What am I missing? Is there an elegant way to do this in Java 8 that (a) avoids the duplicate calculation, and (b) doesn't produce unwanted garbage?

I could do this with a forEach() call, mixing old & new methodologies...so I could at least take advantage of streams, but "optimize" the calc/filter/add in an old-school way. There's gotta be a nice, easy, elegant solution to this. Help me see the light...

user3184922 asked Oct 21 '22 04:10



2 Answers

You can use flatMap to fuse a filter and a map stage, which allows you to hold the computed distance in a local variable until you know it's useful to create a new object.

Here, I've extracted the flatmapper into a helper method, since I prefer that style, but it's certainly possible to inline it as a statement lambda (or even an expression lambda that uses the ? : ternary operator).

Stream<StationWithDistance> nearbyStation(Station s, Coordinate origin, double radius) {
    double distance = origin.distanceTo(s.getCoordinate());
    if (distance <= radius) {
        return Stream.of(new StationWithDistance(s, distance));
    } else {
        return Stream.empty();
    }
}

public List<StationWithDistance> getNearbyStationsNewerWay(Coordinate origin, double radius) {
    return allStations.stream()
        .parallel()
        .flatMap(s -> nearbyStation(s, origin, radius))
        .sorted()
        .collect(Collectors.toList());
}
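
For reference, here is a minimal sketch of the inlined statement-lambda variant mentioned above. It is not part of the original answer: the `Point` class (a 1-D coordinate), the `nearby` method name, and the sample stations are assumptions made only so the example compiles and runs on its own, and the question's types are trimmed down to nested classes.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class InlineFlatMapSketch {

    // Stand-ins for the question's types, trimmed so the sketch runs alone.
    interface Coordinate {
        double distanceTo(Coordinate other);
    }

    // Assumption: a 1-D point, so distances are easy to verify by hand.
    static final class Point implements Coordinate {
        final double x;

        Point(double x) {
            this.x = x;
        }

        @Override
        public double distanceTo(Coordinate other) {
            return Math.abs(x - ((Point) other).x);
        }
    }

    static class Station {
        private final String name;
        private final Coordinate coordinate;

        Station(String name, Coordinate coordinate) {
            this.name = name;
            this.coordinate = coordinate;
        }

        String getName() { return name; }
        Coordinate getCoordinate() { return coordinate; }
    }

    static final class StationWithDistance extends Station
            implements Comparable<StationWithDistance> {
        private final double distance;

        StationWithDistance(Station station, double distance) {
            super(station.getName(), station.getCoordinate());
            this.distance = distance;
        }

        double getDistance() { return distance; }

        @Override
        public int compareTo(StationWithDistance other) {
            return Double.compare(distance, other.distance);
        }
    }

    // flatMap fuses filter and map: the distance is computed once per station,
    // and a StationWithDistance is created only when it falls within the radius.
    static List<StationWithDistance> nearby(List<Station> all, Coordinate origin, double radius) {
        return all.parallelStream()
            .flatMap(s -> {
                double distance = origin.distanceTo(s.getCoordinate());
                return distance <= radius
                    ? Stream.of(new StationWithDistance(s, distance))
                    : Stream.<StationWithDistance>empty();
            })
            .sorted()
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Hypothetical sample data; the origin is at x = 0.
        List<Station> all = Arrays.asList(
            new Station("near", new Point(1)),
            new Station("edge", new Point(5)),
            new Station("mid", new Point(2)),
            new Station("far", new Point(10)));
        for (StationWithDistance s : nearby(all, new Point(0), 5.0)) {
            System.out.println(s.getName() + " " + s.getDistance());
        }
    }
}
```

The explicit type witness on `Stream.<StationWithDistance>empty()` is there only to keep type inference unambiguous. Each element still costs a small `Stream` allocation, so whether this actually reduces garbage compared with map-then-filter is worth measuring rather than assuming.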
Stuart Marks answered Oct 23 '22 10:10



I didn't actually test that it works, but I think you can compute the distance in the map stage, return StationWithDistance or null based on the distance, and then filter out the nulls.

public List<StationWithDistance> getNearbyStationsNewWay(Coordinate origin, double radius) {
    return allStations.stream()
        .parallel()
        .map(s -> {
            double dist = origin.distanceTo(s.getCoordinate());
            return (dist <= radius) ? new StationWithDistance(s, dist) : null;
        })
        .filter(s -> s != null)
        .sorted()
        .collect(Collectors.toList());
}
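
Since the snippet above is untested, here is a small self-contained sketch that exercises the same map-to-null-then-filter idea. It is an illustration under assumptions, not the answerer's code: the `Point` class (a 1-D coordinate), the `nearby` method name, and the sample stations are hypothetical, the question's types are trimmed to nested classes, and `Objects::nonNull` is swapped in for the `s -> s != null` lambda.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class NullFilterSketch {

    // Stand-ins for the question's types, trimmed so the sketch runs alone.
    interface Coordinate {
        double distanceTo(Coordinate other);
    }

    // Assumption: a 1-D point, so distances are easy to verify by hand.
    static final class Point implements Coordinate {
        final double x;

        Point(double x) {
            this.x = x;
        }

        @Override
        public double distanceTo(Coordinate other) {
            return Math.abs(x - ((Point) other).x);
        }
    }

    static class Station {
        private final String name;
        private final Coordinate coordinate;

        Station(String name, Coordinate coordinate) {
            this.name = name;
            this.coordinate = coordinate;
        }

        String getName() { return name; }
        Coordinate getCoordinate() { return coordinate; }
    }

    static final class StationWithDistance extends Station
            implements Comparable<StationWithDistance> {
        private final double distance;

        StationWithDistance(Station station, double distance) {
            super(station.getName(), station.getCoordinate());
            this.distance = distance;
        }

        double getDistance() { return distance; }

        @Override
        public int compareTo(StationWithDistance other) {
            return Double.compare(distance, other.distance);
        }
    }

    // map() computes the distance once and yields null for out-of-range stations;
    // the nulls are filtered out before sorted() ever compares elements.
    static List<StationWithDistance> nearby(List<Station> all, Coordinate origin, double radius) {
        return all.parallelStream()
            .map(s -> {
                double dist = origin.distanceTo(s.getCoordinate());
                return (dist <= radius) ? new StationWithDistance(s, dist) : null;
            })
            .filter(Objects::nonNull)
            .sorted()
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Hypothetical sample data; the origin is at x = 0.
        List<Station> all = Arrays.asList(
            new Station("near", new Point(1)),
            new Station("edge", new Point(5)),
            new Station("mid", new Point(2)),
            new Station("far", new Point(10)));
        for (StationWithDistance s : nearby(all, new Point(0), 5.0)) {
            System.out.println(s.getName() + " " + s.getDistance());
        }
    }
}
```

The order of the stages matters here: the null filter has to come before `sorted()`, because natural-order sorting would call `compareTo` with a null element and fail.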
Eran answered Oct 23 '22 11:10
