
How to use a custom Collector in a groupingBy operation

The Oracle Java Tutorials trail on reduction with streams gives an example of how to convert a collection of people into a map containing the average age by gender. It uses the following Person class and code:

public class Person {
    private int age;

    public enum Sex {
        MALE,
        FEMALE
    }

    private Sex sex;

    public Person (int age, Sex sex) {
        this.age = age;
        this.sex = sex;
    }

    public int getAge() { return this.age; }

    public Sex getSex() { return this.sex; }
}

Map<Person.Sex, Double> averageAgeByGender = roster
    .stream()
    .collect(
        Collectors.groupingBy(
            Person::getSex,                      
            Collectors.averagingInt(Person::getAge)));

The stream code above works well, but I wanted to see how to perform the same operation with a custom Collector implementation. I could not find a complete example of this on Stack Overflow or elsewhere online. As for why we might want to do this: we might, for example, want to compute some kind of weighted average involving the age, and in that case the default behavior of Collectors.averagingInt would not suffice.

Tim Biegeleisen asked Jan 27 '23 09:01


2 Answers

Just use Collector.of(Supplier, BiConsumer, BinaryOperator, [Function,] Characteristics...) for those cases:

// container: a[0] holds the running sum of ages, a[1] the count
Collector.of(() -> new double[2],
        (a, t) -> { a[0] += t.getAge(); a[1]++; },
        (a, b) -> { a[0] += b[0]; a[1] += b[1]; return a; },
        a -> (a[1] == 0) ? 0.0 : a[0] / a[1])
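To make the role of each argument concrete, the following sketch drives the four functions of the Collector.of call above by hand, roughly the way a parallel stream may: two partial containers are filled independently, merged by the combiner, and passed to the finisher. To stay self-contained it averages plain Integer values rather than Person objects, so the accumulator reads the element directly instead of calling getAge(). The class CollectorProtocolDemo and the method averageBySplitting are hypothetical names used only for illustration.

```java
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;

public class CollectorProtocolDemo {

    // Same double[] layout as above: [0] = running sum, [1] = count.
    // Elements are plain Integers here (assumption), not Person objects.
    static final Collector<Integer, double[], Double> AVERAGE = Collector.of(
        () -> new double[2],
        (a, age) -> { a[0] += age; a[1]++; },
        (a, b) -> { a[0] += b[0]; a[1] += b[1]; return a; },
        a -> (a[1] == 0) ? 0.0 : a[0] / a[1]);

    // Accumulates each half of the list into its own container,
    // merges the two containers with the combiner, then applies the finisher.
    static double averageBySplitting(List<Integer> ages) {
        Supplier<double[]> supplier = AVERAGE.supplier();
        BiConsumer<double[], Integer> accumulator = AVERAGE.accumulator();
        BinaryOperator<double[]> combiner = AVERAGE.combiner();
        Function<double[], Double> finisher = AVERAGE.finisher();

        int mid = ages.size() / 2;

        double[] left = supplier.get();
        for (Integer age : ages.subList(0, mid)) {
            accumulator.accept(left, age);
        }

        double[] right = supplier.get();
        for (Integer age : ages.subList(mid, ages.size())) {
            accumulator.accept(right, age);
        }

        return finisher.apply(combiner.apply(left, right));
    }

    public static void main(String[] args) {
        System.out.println(averageBySplitting(List.of(34, 23, 68))); // prints 41.666666666666664
    }
}
```

The split shows why the combiner adds sums and counts separately: averaging the two partial averages (34 and 45.5) would give 39.75 instead of the correct 41.666666666666664.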

Alternatively, it may be more readable to define a dedicated PersonAverager class:

class PersonAverager {
    double sum = 0;
    int count = 0;

    void accept(Person p) {
        sum += p.getAge();
        count++;
    }

    PersonAverager combine(PersonAverager other) {
        sum += other.sum;
        count += other.count;
        return this;
    }

    double average() {
        return count == 0 ? 0 : sum / count;
    }
}

and use it as:

Collector.of(PersonAverager::new,
        PersonAverager::accept,
        PersonAverager::combine,
        PersonAverager::average)
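A runnable sketch of the PersonAverager variant follows, assuming a cut-down copy of the question's Person class; AveragerDemo, sampleRoster and averageAgeBySex are hypothetical names. The explicit type witness Collector.<Person, PersonAverager, Double>of(...) is an addition that pins down the types, whereas the plain Collector.of(...) form above relies on inference. The boolean flag switches to a parallel stream so that PersonAverager::combine can also come into play.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class AveragerDemo {

    // Cut-down copy of the question's Person class (assumed for this sketch).
    static class Person {
        enum Sex { MALE, FEMALE }

        private final int age;
        private final Sex sex;

        Person(int age, Sex sex) { this.age = age; this.sex = sex; }

        int getAge() { return age; }
        Sex getSex() { return sex; }
    }

    // The answer's PersonAverager: a mutable container with a running sum and count.
    static class PersonAverager {
        double sum = 0;
        int count = 0;

        void accept(Person p) {
            sum += p.getAge();
            count++;
        }

        PersonAverager combine(PersonAverager other) {
            sum += other.sum;
            count += other.count;
            return this;
        }

        double average() {
            return count == 0 ? 0 : sum / count;
        }
    }

    static List<Person> sampleRoster() {
        return List.of(
            new Person(34, Person.Sex.MALE),
            new Person(23, Person.Sex.MALE),
            new Person(68, Person.Sex.MALE),
            new Person(14, Person.Sex.FEMALE),
            new Person(58, Person.Sex.FEMALE),
            new Person(27, Person.Sex.FEMALE));
    }

    // Groups by sex and averages ages. With parallel = true, partial
    // PersonAverager instances may be merged via combine().
    static Map<Person.Sex, Double> averageAgeBySex(List<Person> roster, boolean parallel) {
        Stream<Person> stream = parallel ? roster.parallelStream() : roster.stream();
        return stream.collect(Collectors.groupingBy(
            Person::getSex,
            Collector.<Person, PersonAverager, Double>of(
                PersonAverager::new,
                PersonAverager::accept,
                PersonAverager::combine,
                PersonAverager::average)));
    }

    public static void main(String[] args) {
        System.out.println(averageAgeBySex(sampleRoster(), false).get(Person.Sex.MALE));
        System.out.println(averageAgeBySex(sampleRoster(), true).get(Person.Sex.MALE));
    }
}
```

On this sample data both calls should report 41.666666666666664 for MALE: the ages are small integers, so their sums are exact in a double and splitting the work in parallel does not change the result.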
Didier L answered Jan 30 '23 23:01


This answer, which has been tested, is based on several different sources. The source code for Collectors#averagingInt was helpful in figuring out the lambda syntax used below. The result container created by the supplier is a double[] array of length two: the first element stores the cumulative sum of ages, while the second stores the count.

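import java.util.Collections;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.ToIntFunction;
import java.util.stream.Collector;

public class PersonCollector<T extends Person> implements Collector<T, double[], Double> {
    // extracts the int value to average from each Person, e.g. Person::getAge
    private final ToIntFunction<Person> mapper;

    public PersonCollector(ToIntFunction<Person> mapper) {
        this.mapper = mapper;
    }

    // creates a fresh container: [0] = running sum, [1] = count
    @Override
    public Supplier<double[]> supplier() {
        return () -> new double[2];
    }

    // folds one element into a container
    @Override
    public BiConsumer<double[], T> accumulator() {
        return (a, t) -> { a[0] += mapper.applyAsInt(t); a[1]++; };
    }

    // merges two partial containers (needed when a parallel stream splits the work)
    @Override
    public BinaryOperator<double[]> combiner() {
        return (a, b) -> { a[0] += b[0]; a[1] += b[1]; return a; };
    }

    // turns the final container into the average, or 0.0 for an empty group
    @Override
    public Function<double[], Double> finisher() {
        return a -> (a[1] == 0) ? 0.0 : a[0] / a[1];
    }

    @Override
    public Set<Characteristics> characteristics() {
        // do NOT return IDENTITY_FINISH here, which would bypass
        // the custom finisher() above
        return Collections.emptySet();
    }
}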

List<Person> list = new ArrayList<>();
list.add(new Person(34, Person.Sex.MALE));
list.add(new Person(23, Person.Sex.MALE));
list.add(new Person(68, Person.Sex.MALE));
list.add(new Person(14, Person.Sex.FEMALE));
list.add(new Person(58, Person.Sex.FEMALE));
list.add(new Person(27, Person.Sex.FEMALE));

final Collector<Person, double[], Double> pc = new PersonCollector<>(Person::getAge);

Map<Person.Sex, Double> averageAgeBySex = list
  .stream()
  .collect(Collectors.groupingBy(Person::getSex, pc));

System.out.println("Male average: " + averageAgeBySex.get(Person.Sex.MALE));
System.out.println("Female average: " + averageAgeBySex.get(Person.Sex.FEMALE));

This outputs:

Male average: 41.666666666666664
Female average: 33.0

Note that we pass the method reference Person::getAge to the custom collector; it maps each Person in the collection to an integer age value. Also, our characteristics() method does not return Characteristics.IDENTITY_FINISH. Returning it would tell the stream framework that the double[] container is already the final result, so our custom finisher() would be bypassed.
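The same pitfall applies to the Collector.of factory used in the other answer. According to the Collector.of Javadoc, the overload without a finisher always reports IDENTITY_FINISH (its container type is also its result type), whereas the overload that takes a finisher reports only the characteristics passed to it. The sketch below, with the hypothetical class name CharacteristicsDemo, simply prints both sets.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collector;

public class CharacteristicsDemo {

    // Overload WITHOUT a finisher: the container (a List) is itself the result.
    static Set<Collector.Characteristics> withoutFinisher() {
        Collector<Integer, List<Integer>, List<Integer>> c = Collector.of(
            ArrayList::new,
            List::add,
            (a, b) -> { a.addAll(b); return a; });
        return c.characteristics();
    }

    // Overload WITH a finisher and no explicit characteristics (the averaging case).
    static Set<Collector.Characteristics> withFinisher() {
        Collector<Integer, double[], Double> c = Collector.of(
            () -> new double[2],
            (a, x) -> { a[0] += x; a[1]++; },
            (a, b) -> { a[0] += b[0]; a[1] += b[1]; return a; },
            a -> (a[1] == 0) ? 0.0 : a[0] / a[1]);
        return c.characteristics();
    }

    public static void main(String[] args) {
        System.out.println(withoutFinisher()); // prints [IDENTITY_FINISH]
        System.out.println(withFinisher());    // prints []
    }
}
```

So the four-argument Collector.of used for averaging is safe by default; passing IDENTITY_FINISH to it explicitly would reintroduce the bypass described above.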

Tim Biegeleisen answered Jan 30 '23 23:01