
Why does HashSet.removeAll take a quadratic amount of operations?

I have this code that generates a HashSet and calls removeAll() on it. I made a class A which is just a wrapper of an int, which records the number of times equals is called - the program outputs that number.

import java.util.*;

class A {
    int x;
    static int equalsCalls;

    A(int x) {
        this.x = x;
    }

    @Override
    public int hashCode() {
        return x;
    }

    @Override
    public boolean equals(Object o) {
        equalsCalls++;
        if (!(o instanceof A)) return false;
        return x == ((A)o).x;
    }

    public static void main(String[] args) {
        int setSize = Integer.parseInt(args[0]);
        int listSize = Integer.parseInt(args[1]);
        Set<A> s = new HashSet<A>();
        for (int i = 0; i < setSize; i ++)
            s.add(new A(i));
        List<A> l = new LinkedList<A>();
        for (int i = 0; i < listSize; i++)
            l.add(new A(i));
        s.removeAll(l);
        System.out.println(A.equalsCalls);
    }
}

It turns out that the removeAll operation is not linear as I would have expected:

$ java A 10 10
55
$ java A 100 100
5050
$ java A 1000 1000
500500

In fact, it seems to be quadratic. Why?

Replacing the line s.removeAll(l); with for (A b : l) s.remove(b); fixes it to act the way I'd expect:

$ java A 10 10
10
$ java A 100 100
100
$ java A 1000 1000
1000

Why?

Here's a graph showing the quadratic curve:

[Image: 3D surface plot of the number of A.equals calls against the two program arguments]

The x and y axes are the two arguments to the Java program. The z axis is the number of A.equals calls.


The graph was generated by this Asymptote program:

import graph3;
import grid3;
import palette;

currentprojection=orthographic(0.8,1,1);

size(400,300,IgnoreAspect);

defaultrender.merge=true;

real[][] a = new real[][]{
  new real[]{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0},
  new real[]{0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1},
  new real[]{0,1,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3},
  new real[]{0,1,2,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6},
  new real[]{0,1,2,3,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10},
  new real[]{0,1,2,3,4,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15},
  new real[]{0,1,2,3,4,5,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21},
  new real[]{0,1,2,3,4,5,6,28,28,28,28,28,28,28,28,28,28,28,28,28,28},
  new real[]{0,1,2,3,4,5,6,7,36,36,36,36,36,36,36,36,36,36,36,36,36},
  new real[]{0,1,2,3,4,5,6,7,8,45,45,45,45,45,45,45,45,45,45,45,45},
  new real[]{0,1,2,3,4,5,6,7,8,9,55,55,55,55,55,55,55,55,55,55,55},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,66,66,66,66,66,66,66,66,66,66},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,78,78,78,78,78,78,78,78,78},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,91,91,91,91,91,91,91,91},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,105,105,105,105,105,105,105},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,120,120,120,120,120,120},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,136,136,136,136,136},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,153,153,153,153},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,171,171,171},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,190,190},
  new real[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,210},
};


surface s=surface(a,(-1/2,-1/2),(1/2,1/2),Spline);
draw(s,mean(palette(s.map(zpart),Rainbow())),black);
//grid3(XYZgrid);

The array a was generated by this Haskell program:

import System.Process
import Data.List

f :: Integer -> Integer -> IO Integer
f x y = fmap read $ readProcess "/usr/bin/java" ["A", show x, show y] ""

g :: [[Integer]] -> [String]
g xss = map (\xs -> "new real[]{" ++ intercalate "," (map show xs) ++ "},") xss

main = do
  let n = 20
  xs <- mapM (\x -> mapM (\y -> f x y) [0..n]) [0..n]
  putStrLn $ unlines $ g xs

asked Jul 25 '13 02:07 by Dog


1 Answer

How long removeAll takes depends on what kind of collection you pass and on how its size compares with the set's. HashSet inherits removeAll from AbstractSet; when you look at that implementation:

public boolean removeAll(Collection<?> c) {
    boolean modified = false;

    if (size() > c.size()) {
        for (Iterator<?> i = c.iterator(); i.hasNext(); )
            modified |= remove(i.next());
    } else {
        for (Iterator<?> i = iterator(); i.hasNext(); ) {
            if (c.contains(i.next())) {
                i.remove();
                modified = true;
            }
        }
    }
    return modified;
}

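one can see that when the HashSet is not larger than the collection (size() <= c.size(), which includes your equal-size case), it iterates over the HashSet and calls c.contains for each element. Since you are using a LinkedList as the argument, each contains call is an O(n) linear scan. Since it needs to do this for each of the n elements being removed, the result is an O(n^2) operation. More precisely, finding the element at index i of the list costs i + 1 equals calls, so the total is 1 + 2 + ... + n = n(n + 1)/2, which is exactly the 55, 5050 and 500500 you measured.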
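
To make the two branches concrete, here is a minimal sketch that runs each of them by hand, without calling removeAll itself, and counts equals calls. The names RemoveAllBranches, Key, containsBranch and removeBranch are hypothetical and introduced only for illustration; Key is modeled on the question's class A.

```java
import java.util.*;

class RemoveAllBranches {
    // Hypothetical key type modeled on the question's class A:
    // an int wrapper that counts every equals() call.
    static final class Key {
        static long equalsCalls = 0;
        final int x;
        Key(int x) { this.x = x; }
        @Override public int hashCode() { return x; }
        @Override public boolean equals(Object o) {
            equalsCalls++;
            return o instanceof Key && x == ((Key) o).x;
        }
    }

    static Set<Key> newSet(int n) {
        Set<Key> s = new HashSet<>();
        for (int i = 0; i < n; i++) s.add(new Key(i));
        return s;
    }

    static List<Key> newList(int n) {
        List<Key> l = new LinkedList<>();
        for (int i = 0; i < n; i++) l.add(new Key(i));
        return l;
    }

    // Branch taken when size() <= c.size(): walk the set, ask the list.
    static long containsBranch(int n) {
        Set<Key> set = newSet(n);
        List<Key> list = newList(n);
        Key.equalsCalls = 0; // count only the removal phase
        for (Iterator<Key> it = set.iterator(); it.hasNext(); ) {
            if (list.contains(it.next())) it.remove();
        }
        return Key.equalsCalls;
    }

    // Branch taken when size() > c.size(): walk the list, ask the set.
    static long removeBranch(int n) {
        Set<Key> set = newSet(n);
        List<Key> list = newList(n);
        Key.equalsCalls = 0; // count only the removal phase
        for (Key k : list) set.remove(k);
        return Key.equalsCalls;
    }

    public static void main(String[] args) {
        System.out.println(containsBranch(1000)); // 500500, as in the question
        System.out.println(removeBranch(1000));   // 1000, as in the question
    }
}
```

The first loop pays for a linear LinkedList scan on every element of the set, while the second pays for one hash lookup per element of the list.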

If you replace the LinkedList with a collection that offers a more efficient implementation of contains, then the performance of removeAll should improve accordingly. If you make the HashSet larger than the collection, performance should also improve dramatically, since it will then iterate over the collection.
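
One way to apply the first suggestion is to copy the argument into a HashSet before calling removeAll, so that contains becomes an expected constant-time lookup whichever branch is taken. This is a minimal sketch, not the only fix; the class RemoveAllFix and the helper removeAllFast are hypothetical names, and it assumes the elements have consistent hashCode and equals implementations.

```java
import java.util.*;

class RemoveAllFix {
    // Hypothetical helper: copies toRemove into a HashSet first, so that
    // whichever branch removeAll takes, each lookup is an expected O(1) hash probe.
    static boolean removeAllFast(Set<?> set, Collection<?> toRemove) {
        return set.removeAll(new HashSet<Object>(toRemove));
    }

    public static void main(String[] args) {
        Set<Integer> set = new HashSet<>();
        List<Integer> list = new LinkedList<>();
        for (int i = 0; i < 1000; i++) {
            set.add(i);
            list.add(i);
        }
        boolean changed = removeAllFast(set, list);
        System.out.println(changed + " " + set.isEmpty()); // prints "true true"
    }
}
```

The copy costs O(n) extra time and memory up front, which is usually a good trade against n linear scans of the list.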

answered Sep 28 '22 18:09 by Ted Hopp