Calculate weighted average with Java 8 streams

How do I calculate the weighted mean of a Map<Double, Integer>, where each Integer value is the weight of the Double key to be averaged? For example, the map has the following entries:

  1. (0.7, 100) // value is 0.7 and weight is 100
  2. (0.5, 200)
  3. (0.3, 300)
  4. (0.0, 400)

I want to apply the following formula using Java 8 streams, but I'm unsure how to calculate the numerator and the denominator together while preserving both. How can I use reduction here?

$$\bar{x} = \frac{\sum_{i} x_i w_i}{\sum_{i} w_i}$$

where x_i is each Double value and w_i is its Integer weight.
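As a quick check of what the stream code should produce, here are the four example entries (value x_i, weight w_i) substituted into the weighted mean:

$$\bar{x} = \frac{0.7 \cdot 100 + 0.5 \cdot 200 + 0.3 \cdot 300 + 0.0 \cdot 400}{100 + 200 + 300 + 400} = \frac{70 + 100 + 90 + 0}{1000} = \frac{260}{1000} = 0.26$$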

Vivek Sethi asked Nov 04 '16 10:11


People also ask

Which method can be used to compute the average of a stream?

The average() method of the IntStream class returns an OptionalDouble describing the arithmetic mean of the stream's elements, or an empty OptionalDouble if the stream is empty.
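A minimal sketch of IntStream.average() in use; the class name AverageDemo and the helper mean are hypothetical. Note that average() computes only the unweighted mean, which is why the question needs something else for weights.

```java
import java.util.OptionalDouble;
import java.util.stream.IntStream;

public class AverageDemo {
    // Unweighted arithmetic mean; the result is empty when no values are given
    public static OptionalDouble mean(int... values) {
        return IntStream.of(values).average();
    }

    public static void main(String[] args) {
        System.out.println(mean(100, 200, 300, 400)); // OptionalDouble[250.0]
        System.out.println(mean());                   // OptionalDouble.empty
    }
}
```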

How can I calculate weighted average?

To find a weighted average when the weights add up to one, multiply each value by its weight, then add the results. If the weights don't add up to one, divide that sum of products by the sum of the weights.
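The two cases can be sketched in Java, using the question's data. The class NormalizedWeights and its methods are hypothetical names for illustration: rescaling the weights so they sum to one and then summing products gives, up to floating-point rounding, the same result as dividing by the sum of the raw weights.

```java
import java.util.Arrays;

public class NormalizedWeights {
    // Sum of value * weight (already the weighted average if the weights sum to 1)
    public static double sumOfProducts(double[] values, double[] weights) {
        double sum = 0;
        for (int i = 0; i < values.length; i++) {
            sum += values[i] * weights[i];
        }
        return sum;
    }

    // Arbitrary weights: divide the sum of products by the sum of the weights
    public static double weightedAverage(double[] values, double[] weights) {
        return sumOfProducts(values, weights) / Arrays.stream(weights).sum();
    }

    // Rescale the weights so that they sum to 1
    public static double[] normalize(double[] weights) {
        double total = Arrays.stream(weights).sum();
        return Arrays.stream(weights).map(w -> w / total).toArray();
    }

    public static void main(String[] args) {
        double[] values = {0.7, 0.5, 0.3, 0.0};
        double[] weights = {100, 200, 300, 400};
        System.out.println(weightedAverage(values, weights)); // 0.26
        // Also about 0.26; the last digits may differ because of rounding
        System.out.println(sumOfProducts(values, normalize(weights)));
    }
}
```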

How do you calculate weighted average using SUMPRODUCT?

In a spreadsheet, combine the SUMPRODUCT and SUM functions: SUMPRODUCT multiplies each test score by its weight and adds the resulting products, and dividing that result by the SUM of the weights gives the weighted average.
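The same SUMPRODUCT / SUM recipe can be sketched with Java streams. The class SumProduct, its methods, and the scores 90, 80, 70 with weights 20, 30, 50 are hypothetical, chosen only to illustrate the spreadsheet formula.

```java
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

public class SumProduct {
    // Analogue of SUMPRODUCT(a, b): multiply pairwise, then add the products
    public static double sumProduct(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        return IntStream.range(0, a.length)
                .mapToDouble(i -> a[i] * b[i])
                .sum();
    }

    // SUMPRODUCT(scores, weights) / SUM(weights)
    public static double weightedAverage(double[] scores, double[] weights) {
        return sumProduct(scores, weights) / DoubleStream.of(weights).sum();
    }

    public static void main(String[] args) {
        double[] scores = {90, 80, 70};   // hypothetical test scores
        double[] weights = {20, 30, 50};  // hypothetical weights (percent)
        System.out.println(weightedAverage(scores, weights)); // 77.0
    }
}
```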

How is weighted ASP calculated?

To calculate a weighted average price per share, multiply each purchase price by the number of shares bought at that price, add the products together, and then divide by the total number of shares.
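For example, with two hypothetical purchases, 100 shares at $10 and 50 shares at $16:

$$\text{average price per share} = \frac{100 \cdot 10 + 50 \cdot 16}{100 + 50} = \frac{1000 + 800}{150} = \frac{1800}{150} = 12$$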


3 Answers

You can create your own collector for this task:

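import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.Collector;

static <T> Collector<T,?,Double> averagingWeighted(ToDoubleFunction<T> valueFunction, ToIntFunction<T> weightFunction) {
    // Mutable accumulator holding the running numerator and denominator
    class Box {
        double num = 0;   // sum of value * weight
        long denom = 0;   // sum of weights
    }
    return Collector.of(
             Box::new,
             (b, e) -> {
                 int w = weightFunction.applyAsInt(e);   // evaluate the weight once per element
                 b.num += valueFunction.applyAsDouble(e) * w;
                 b.denom += w;
             },
             // combiner (used by parallel streams): merge two partial results
             (b1, b2) -> { b1.num += b2.num; b1.denom += b2.denom; return b1; },
             // finisher: an empty stream gives 0.0 / 0, i.e. NaN
             b -> b.num / b.denom
           );
}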

This custom collector takes two functions as parameters: one returns the value to use for a given stream element (a ToDoubleFunction), and the other returns its weight (a ToIntFunction). A local helper class stores the numerator and denominator during the collecting process. Each time an element is accepted, the numerator is increased by the value multiplied by its weight, and the denominator is increased by the weight. The finisher then returns the quotient of the two as a Double (NaN for an empty stream, because floating-point division by zero does not throw).

A sample usage would be:

Map<Double,Integer> map = new HashMap<>();
map.put(0.7, 100);
map.put(0.5, 200);

double weightedAverage =
  map.entrySet().stream().collect(averagingWeighted(Map.Entry::getKey, Map.Entry::getValue));
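For reference, here is a self-contained sketch that runs the collector above on all four entries from the question and on an empty map. The class WeightedCollectorDemo and the methods questionExample and emptyExample are hypothetical wrappers; averagingWeighted is the answer's collector with the weight evaluated once per element.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.Collector;

public class WeightedCollectorDemo {
    // Single pass: numerator and denominator travel together in one accumulator
    static <T> Collector<T, ?, Double> averagingWeighted(ToDoubleFunction<T> valueFunction,
                                                         ToIntFunction<T> weightFunction) {
        class Box {
            double num = 0;  // running sum of value * weight
            long denom = 0;  // running sum of weights
        }
        return Collector.of(
                Box::new,
                (b, e) -> {
                    int w = weightFunction.applyAsInt(e);
                    b.num += valueFunction.applyAsDouble(e) * w;
                    b.denom += w;
                },
                (b1, b2) -> { b1.num += b2.num; b1.denom += b2.denom; return b1; },
                b -> b.num / b.denom);
    }

    // All four entries from the question: value -> weight
    public static double questionExample() {
        Map<Double, Integer> map = new HashMap<>();
        map.put(0.7, 100);
        map.put(0.5, 200);
        map.put(0.3, 300);
        map.put(0.0, 400);
        return map.entrySet().stream()
                .collect(averagingWeighted(Map.Entry::getKey, Map.Entry::getValue));
    }

    // An empty map: 0.0 / 0 evaluates to NaN instead of throwing
    public static double emptyExample() {
        Map<Double, Integer> empty = new HashMap<>();
        return empty.entrySet().stream()
                .collect(averagingWeighted(Map.Entry::getKey, Map.Entry::getValue));
    }

    public static void main(String[] args) {
        System.out.println(questionExample()); // 0.26
        System.out.println(emptyExample());    // NaN
    }
}
```

If NaN is not acceptable for empty input, the finisher is the natural place to check b.denom == 0 and return or throw something else.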
Tunaki answered Oct 09 '22 20:10


You can use the following method to calculate the weighted average of a map. Each entry's key holds the value, and the entry's value holds its weight.

    /**
     * Calculates the weighted average of a map.
     *
     * Double division never throws ArithmeticException, so an empty map
     * produces NaN (0.0 / 0.0) rather than an exception.
     *
     * @param map A map of values (keys) and weights (entry values)
     * @return The weighted average of the map, or NaN if the map is empty
     */
    static Double calculateWeightedAverage(Map<Double, Integer> map) {
        double num = 0;    // sum of value * weight
        double denom = 0;  // sum of weights
        for (Map.Entry<Double, Integer> entry : map.entrySet()) {
            num += entry.getKey() * entry.getValue();
            denom += entry.getValue();
        }

        return num / denom;
    }

The following unit test shows a use case:

    /**
     * Tests our method to calculate the weighted average.
     */
    @Test
    public void testAveragingWeighted() {
        Map<Double, Integer> map = new HashMap<>();
        map.put(0.7, 100);
        map.put(0.5, 200);
        Double weightedAverage = calculateWeightedAverage(map);
        // (0.7 * 100 + 0.5 * 200) / (100 + 200) = 170 / 300
        Assert.assertEquals(0.5666666666666667, weightedAverage, 1e-12);
    }

You need these imports for the unit tests:

import org.junit.Assert;
import org.junit.Test;

You need these imports for the code:

import java.util.HashMap;
import java.util.Map;

I hope it helps.

Mohammad answered Oct 09 '22 19:10


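import java.util.Collection;
import java.util.Map;

// Each entry: key = value, entry value = weight (uses `var`, so Java 10+)
public static double weightedAvg(Collection<? extends Map.Entry<? extends Number, ? extends Number>> data) {
    // Denominator: sum of the weights (the entry values)
    var sumWeights = data.stream()
        .mapToDouble(e -> e.getValue().doubleValue())
        .sum();
    // Numerator: sum of value * weight
    var sumData = data.stream()
        .mapToDouble(e -> e.getKey().doubleValue() * e.getValue().doubleValue())
        .sum();
    return sumData / sumWeights;
}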
Gustavo Dias answered Oct 09 '22 20:10
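Since the question asks specifically about reduction, here is a hedged alternative to making two passes over the data: a single Stream.reduce that carries the numerator and denominator together in a two-element double array. The class WeightedReduce is a hypothetical name, and the array-as-pair representation is just one possible choice.

```java
import java.util.HashMap;
import java.util.Map;

public class WeightedReduce {
    // Single-pass weighted average via Stream.reduce; each entry is value -> weight
    public static double weightedAverage(Map<Double, Integer> map) {
        double[] sums = map.entrySet().stream()
                // turn each entry into the pair {value * weight, weight}
                .map(e -> new double[] {e.getKey() * e.getValue(), e.getValue()})
                // add pairs component-wise; new arrays keep the identity {0, 0} untouched
                .reduce(new double[] {0, 0},
                        (a, b) -> new double[] {a[0] + b[0], a[1] + b[1]});
        return sums[0] / sums[1];  // NaN for an empty map
    }

    public static void main(String[] args) {
        Map<Double, Integer> map = new HashMap<>();
        map.put(0.7, 100);
        map.put(0.5, 200);
        map.put(0.3, 300);
        map.put(0.0, 400);
        System.out.println(weightedAverage(map)); // 0.26
    }
}
```

Allocating a new array per step keeps the reduction side-effect free, so it is also safe on a parallel stream; the custom collector in the first answer avoids those allocations by mutating a single accumulator.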