Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Why is my bottom-up merge sort so slow in Java?

I spent the past few hours trying to figure out why the Java version of my sorting algorithm was twice as slow as a recursive merge sort, since the C and C++ versions were 40-50% faster. I kept removing more and more code until I had stripped everything down to a simple loop and merge, but it was still twice as slow. Why is this so slow only in Java?

For reference, here's what a bottom-up merge sort might look like:

/**
 * Bottom-up merge sort: merges adjacent runs of width 1, 2, 4, ... until a
 * single run covers the whole array.
 *
 * @param a    array to sort in place
 * @param aux  scratch array handed through to {@code merge}
 * @param comp ordering handed through to {@code merge}
 */
public static <T> void sort(T[] a, T[] aux, Comparator<T> comp) {
    final int total = a.length;
    for (int width = 1; width < total; width *= 2) {
        final int step = 2 * width;
        for (int lo = 0; lo < total - width; lo += step) {
            final int mid = lo + width - 1;
            // The last pair of runs may be cut short by the end of the array.
            final int hi = Math.min(lo + step - 1, total - 1);
            merge(a, aux, lo, mid, hi, comp);
        }
    }
}

And here's the recursive version:

/**
 * Top-down (recursive) merge sort of the inclusive range {@code a[lo..hi]}.
 *
 * @param a    array to sort in place
 * @param aux  scratch array handed through to {@code merge}
 * @param lo   index of the first element of the range (inclusive)
 * @param hi   index of the last element of the range (inclusive)
 * @param comp ordering handed through to {@code merge}
 */
public static <T> void sort(T[] a, T[] aux, int lo, int hi, Comparator<T> comp) {
    // Base case: a range of zero or one element is already sorted. Without
    // this guard the recursion never terminates.
    if (hi <= lo) return;
    int mid = lo + (hi - lo) / 2;
    sort(a, aux, lo, mid, comp);
    sort(a, aux, mid + 1, hi, comp);
    merge(a, aux, lo, mid, hi, comp);
}

Those are basically just copied from the algorithms on this website. As a last resort I figured I'd copy and paste something from online, but it too is twice as slow as the recursive version.

Is there something "special" about Java that I'm missing?

EDIT: As requested, here's some code:

import java.util.*;
import java.lang.*;
import java.io.*;

/**
 * Element type used by the benchmark: a sort key plus the position the
 * element was created at.
 */
class Test {
    /** {@code value} is the sort key; {@code index} is the creation position. */
    public int value, index;
}

/** Orders {@link Test} instances by {@code value} only; {@code index} is ignored. */
class TestComparator implements Comparator<Test> {
    /**
     * @return -1 if {@code a.value} is smaller, 1 if it is larger, 0 if equal
     */
    @Override
    public int compare(Test a, Test b) {
        return (a.value < b.value) ? -1
             : (a.value > b.value) ? 1
             : 0;
    }
}


/**
 * Two merge sort drivers (top-down and bottom-up) that share one merge routine.
 * Both sort in place and keep equal elements in their original relative order.
 */
class Merge<T> {
    /**
     * Merges the adjacent sorted runs {@code array[start, mid)} and
     * {@code array[mid, end)} in place.
     *
     * Only the left run is copied to {@code buffer}; the right run is consumed
     * directly from {@code array}. This is safe because the write position can
     * never overtake the read position of the right run.
     *
     * @param array  array holding both runs
     * @param start  first index of the left run (inclusive)
     * @param mid    first index of the right run; also the end of the left run
     * @param end    end of the right run (exclusive)
     * @param comp   ordering to merge by
     * @param buffer scratch space with room for at least {@code mid - start} elements
     */
    private static <T> void mergeRuns(T[] array, int start, int mid, int end, Comparator<? super T> comp, T[] buffer) {
        final int leftLength = mid - start;
        final int rightLength = end - mid;
        System.arraycopy(array, start, buffer, 0, leftLength);

        int left = 0, right = 0, out = start;
        while (left < leftLength && right < rightLength) {
            // On ties take the left element first: this keeps the sort stable.
            if (comp.compare(array[mid + right], buffer[left]) >= 0) {
                array[out++] = buffer[left++];
            } else {
                array[out++] = array[mid + right++];
            }
        }
        // Leftover right-run elements are already in their final positions;
        // only leftover left-run elements have to be copied back.
        System.arraycopy(buffer, left, array, out, leftLength - left);
    }

    /**
     * Recursively sorts {@code array[start, end)}. Ranges of two elements are
     * handled with a single compare-and-swap; shorter ranges are left alone.
     */
    private static <T> void sortRange(T[] array, int start, int end, T[] buffer, Comparator<? super T> comp) {
        final int count = end - start;
        if (count <= 2) {
            if (count == 2 && comp.compare(array[start], array[end - 1]) > 0) {
                T swap = array[start];
                array[start] = array[end - 1];
                array[end - 1] = swap;
            }
            return;
        }

        final int mid = start + count / 2;
        sortRange(array, start, mid, buffer, comp);
        sortRange(array, mid, end, buffer, comp);
        mergeRuns(array, start, mid, end, comp, buffer);
    }

    /**
     * Sorts {@code array} in place with a top-down (recursive) merge sort.
     *
     * @param array array to sort
     * @param comp  ordering to sort by
     */
    public static <T> void Recursive(T[] array, Comparator<? super T> comp) {
        // The scratch array is only ever used as an Object[] inside this class,
        // so the unchecked cast never escapes.
        @SuppressWarnings("unchecked")
        T[] buffer = (T[]) new Object[array.length];
        sortRange(array, 0, array.length, buffer, comp);
    }

    /**
     * Sorts {@code array} in place with a bottom-up (iterative) merge sort.
     * Adjacent pairs are ordered first, then runs of length 2, 4, 8, ... are
     * merged until one run covers the whole array.
     *
     * @param array array to sort
     * @param comp  ordering to sort by
     */
    public static <T> void BottomUp(T[] array, Comparator<? super T> comp) {
        // The scratch array is only ever used as an Object[] inside this class,
        // so the unchecked cast never escapes.
        @SuppressWarnings("unchecked")
        T[] buffer = (T[]) new Object[array.length];

        final int size = array.length;
        for (int index = 0; index < size - 1; index += 2) {
            if (comp.compare(array[index], array[index + 1]) > 0) {
                T swap = array[index];
                array[index] = array[index + 1];
                array[index + 1] = swap;
            }
        }

        // All bounds below are computed by subtracting from size rather than by
        // adding run lengths together, so they cannot overflow int even for
        // arrays with more than 2^30 elements.
        int runLength = 2;
        while (runLength < size) {
            int start = 0;
            while (start < size - runLength) {
                final int mid = start + runLength;
                // Equivalent to min(mid + runLength, size) without the addition overflowing.
                final int end = (size - mid > runLength) ? mid + runLength : size;
                mergeRuns(array, start, mid, end, comp, buffer);
                if (end == size) {
                    break;
                }
                start = end;
            }
            if (runLength >= size - runLength) {
                break; // doubling would reach or pass size: the array is one run now
            }
            runLength += runLength;
        }
    }
}


/** Thin static wrapper around one shared, lazily created {@link Random}. */
class SortRandom {
    /** Shared generator; created on first use if nobody assigned one. */
    public static Random rand;

    /**
     * @param max exclusive upper bound, passed straight to {@link Random#nextInt(int)}
     * @return the next value from the shared generator
     */
    public static int nextInt(int max) {
        Random generator = rand;
        if (generator == null) {
            // First call: create the generator (default seeding) and publish it.
            generator = new Random();
            rand = generator;
        }
        return generator.nextInt(max);
    }

    /** @return the next value in {@code [0, Integer.MAX_VALUE)} */
    public static int nextInt() {
        return nextInt(Integer.MAX_VALUE);
    }
}

/**
 * Benchmark driver: for growing array sizes, times {@code Merge.BottomUp}
 * against {@code Merge.Recursive} on identical input and then checks that both
 * produced the same order.
 */
class Sorter {
    /**
     * Runs the comparison for sizes 0, 32768, 65536, ... below 1,500,000.
     *
     * @throws IllegalStateException if the two sorted arrays disagree
     */
    public static void main (String[] args) throws java.lang.Exception {
        int max_size = 1500000;
        TestComparator comp = new TestComparator();

        for (int total = 0; total < max_size; total += 2048 * 16) {
            Test[] array1 = new Test[total];
            Test[] array2 = new Test[total];

            for (int index = 0; index < total; index++) {
                Test item = new Test();

                item.value = SortRandom.nextInt();
                item.index = index;

                // Both arrays share the same objects, so they start out identical.
                array1[index] = item;
                array2[index] = item;
            }

            // System.nanoTime() is the monotonic, high-resolution clock meant for
            // measuring elapsed time; currentTimeMillis() is too coarse here.
            long begin = System.nanoTime();
            Merge.BottomUp(array1, comp);
            double time1 = System.nanoTime() - begin;

            begin = System.nanoTime();
            Merge.Recursive(array2, comp);
            double time2 = System.nanoTime() - begin;

            if (time1 <= 0 || time2 <= 0)
                // A zero reading would turn the ratio below into NaN or Infinity.
                System.out.println("too fast to measure");
            else if (time1 >= time2)
                System.out.format("%f%% as fast\n", time2/time1 * 100.0);
            else
                System.out.format("%f%% faster\n", time2/time1 * 100.0 - 100.0);

            System.out.println("verifying...");
            for (int index = 0; index < total; index++) {
                if (comp.compare(array1[index], array2[index]) != 0)
                    throw new IllegalStateException("value mismatch at position " + index + " for size " + total);
                // Both sorts are expected to be stable, so equal keys must also
                // appear in the same order.
                if (array2[index].index != array1[index].index)
                    throw new IllegalStateException("order mismatch at position " + index + " for size " + total);
            }
            System.out.println("correct!");
        }
    }
}

And here's a C++ version:

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <utility>

// Element type used by the benchmark: a sort key plus the position the
// element was created at.
struct Test {
    size_t value;  // sort key
    size_t index;  // creation position
};

// "Less than" predicate on Test::value; Test::index is ignored.
bool TestCompare(Test item1, Test item2) {
    const bool first_is_smaller = item1.value < item2.value;
    return first_is_smaller;
}

// Two merge sort drivers (top-down and bottom-up) that share one merge routine.
// `compare(a, b)` must return true when a orders strictly before b.
namespace Merge {
    // Merges the adjacent sorted runs array[start, mid) and array[mid, end) in place.
    // Only the left run is copied to `buffer` (which needs room for mid - start
    // elements); the right run is consumed directly from `array`, which is safe
    // because the write position never overtakes the right run's read position.
    template <typename T, typename Comparison>
    void Merge(T array[], int start, int mid, int end, Comparison compare, T buffer[]) {
        const int left_length = mid - start;
        const int right_length = end - mid;
        std::copy(array + start, array + mid, buffer);

        int left = 0, right = 0, out = start;
        while (left < left_length && right < right_length) {
            // On ties take the left element first: this keeps the sort stable.
            if (!compare(array[mid + right], buffer[left]))
                array[out++] = buffer[left++];
            else
                array[out++] = array[mid + right++];
        }
        // Leftover right-run elements are already in place; copy back the left ones.
        std::copy(buffer + left, buffer + left_length, array + out);
    }

    // Recursively sorts array[start, end). Ranges of two elements are handled
    // with a single compare-and-swap; shorter ranges are left alone.
    template <typename T, typename Comparison>
    void SortR(T array[], int start, int end, T buffer[], Comparison compare) {
        const int count = end - start;
        if (count <= 2) {
            if (count == 2 && compare(array[end - 1], array[start]))
                std::swap(array[start], array[end - 1]);
            return;
        }

        const int mid = start + count / 2;
        SortR(array, start, mid, buffer, compare);
        SortR(array, mid, end, buffer, compare);
        Merge(array, start, mid, end, compare, buffer);
    }

    // Sorts array[0, size) in place with a top-down (recursive) merge sort.
    template <typename T, typename Comparison>
    void Recursive(T array[], int size, Comparison compare) {
        T *buffer = new T[size];
        SortR(array, 0, size, buffer, compare);
        delete[] buffer;
    }

    // Sorts array[0, size) in place with a bottom-up (iterative) merge sort:
    // adjacent pairs are ordered first, then runs of length 2, 4, 8, ... are
    // merged until one run covers the whole array.
    template <typename T, typename Comparison>
    void BottomUp(T array[], int size, Comparison compare) {
        T *buffer = new T[size];

        for (int index = 0; index < size - 1; index += 2) {
            if (compare(array[index + 1], array[index]))
                std::swap(array[index], array[index + 1]);
        }

        // All bounds below are computed by subtracting from size rather than by
        // adding run lengths together: signed overflow is undefined behaviour,
        // and the sums could exceed INT_MAX for size > 2^30.
        int run_length = 2;
        while (run_length < size) {
            int start = 0;
            while (start < size - run_length) {
                const int mid = start + run_length;
                // Equivalent to min(mid + run_length, size) without overflow.
                const int end = (size - mid > run_length) ? mid + run_length : size;
                Merge(array, start, mid, end, compare, buffer);
                if (end == size)
                    break;
                start = end;
            }
            if (run_length >= size - run_length)
                break;  // doubling would reach or pass size: the array is one run now
            run_length += run_length;
        }

        delete[] buffer;
    }
}

// Benchmark driver: for growing array sizes, times Merge::BottomUp against
// Merge::Recursive on identical input, then checks that both produced the
// same order.
int main() {
    srand(time(NULL));
    int max_size = 1500000;
    for (int total = 0; total < max_size; total += 2048 * 16) {
        Test *array1 = new Test[total];
        Test *array2 = new Test[total];

        for (int index = 0; index < total; index++) {
            Test item;
            item.value = rand();
            item.index = index;

            // Both arrays receive the same values, so they start out identical.
            array1[index] = item;
            array2[index] = item;
        }

        // Each measurement is (clock after) - (clock before). Storing only the
        // "after" reading would compare absolute process times, not durations.
        double time1 = clock() * 1.0/CLOCKS_PER_SEC;
        Merge::BottomUp(array1, total, TestCompare);
        time1 = clock() * 1.0/CLOCKS_PER_SEC - time1;

        double time2 = clock() * 1.0/CLOCKS_PER_SEC;
        Merge::Recursive(array2, total, TestCompare);
        time2 = clock() * 1.0/CLOCKS_PER_SEC - time2;

        if (time1 <= 0 || time2 <= 0)
            // A zero reading would turn the ratio below into NaN or infinity.
            std::cout << "too fast to measure" << std::endl;
        else if (time1 >= time2)
            std::cout << time2/time1 * 100.0 << "% as fast" << std::endl;
        else
            std::cout << time2/time1 * 100.0 - 100.0 << "% faster" << std::endl;

        std::cout << "verifying... ";
        for (int index = 0; index < total; index++) {
            assert(array1[index].value == array2[index].value);
            // Both sorts are expected to be stable, so equal keys must also
            // appear in the same order.
            assert(array2[index].index == array1[index].index);
        }
        std::cout << "correct!" << std::endl;

        delete[] array1;
        delete[] array2;
    }
    return 0;
}

The differences aren't as drastic as the original versions, but the C++ iterative version is faster while the Java iterative version is slower.

(and yeah, I realize these versions kinda suck and allocate more memory than is used)

Update 2: When I switched the bottom-up merge sort over to a postorder traversal, which closely matches the order of array accesses in the recursive version, it finally started running about 10% faster than the recursive version. So it looks like it has to do with cache misses and not micro-benchmarks or an unpredictable JVM.

The reason it only affects the Java version may be because Java lacks the custom value types used in the C++ version. I'll allocate all of the Test classes separately in the C++ version and see what happens to the performance. The sorting algorithm I'm working on can't be easily adapted to this type of traversal, but if the performance tanks in the C++ version too I might not have much of a choice.

Update 3: Nope, switching the C++ version over to allocated classes did not seem to have any appreciable effect on its performance. It sure seems like it's caused by something with Java specifically.

like image 421
BonzaiThePenguin Avatar asked Apr 16 '14 23:04

BonzaiThePenguin


People also ask

Why is merge sort so slow?

The main reason behind the increased running time is possibly the heap allocation of the array, as stack allocation is faster than heap allocation (here "heap" means the memory space, not the heap data structure used by heapsort).

Is bottom-up merge sort faster?

In all the benchmarks I've run, bottom up merge sort is faster than top down, which would explain why most libraries use some variation of bottom up merge sort.

Which is better top down or bottom-up merge sort?

This means that, comparison-wise, top-down is always better than bottom-up. Furthermore, it has been proved that the "half-half" splitting strategy of top-down merge sort is optimal.

What is the average time complexity of bottom-up merge sort?

O(n log n)


1 Answers

Interesting question. I couldn't figure out why the bottom-up version is slower than the recursive one, given that for an array whose size is a power of two they work identically.

At least in my measurements bottom-up is only slightly slower, not twice as slow.

Benchmark                             Mode          Mean   Mean error    Units
RecursiveVsBottomUpSort.bottomUp      avgt        64.436        0.376    us/op
RecursiveVsBottomUpSort.recursive     avgt        58.902        0.552    us/op

Code:

@OutputTimeUnit(TimeUnit.MICROSECONDS)
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 10, time = 1)
@State(Scope.Thread)
@Threads(1)
@Fork(1)
/**
 * Benchmark state and the two benchmarked methods: a bottom-up and a
 * recursive merge sort over the same {@code int[]} of {@link #N} elements.
 */
public class RecursiveVsBottomUpSort {

    static final int N = 1024;
    int[] a = new int[N];
    int[] aux = new int[N];

    /** Refills {@link #a} with fresh random values; annotated to run per invocation. */
    @Setup(Level.Invocation)
    public void fill() {
        Random r = ThreadLocalRandom.current();
        int slot = 0;
        while (slot < N) {
            a[slot++] = r.nextInt();
        }
    }

    /** Sorts {@code st.a} by merging runs of width 1, 2, 4, ...; returns the last element. */
    @GenerateMicroBenchmark
    public static int bottomUp(RecursiveVsBottomUpSort st) {
        final int[] data = st.a;
        final int[] scratch = st.aux;
        final int total = data.length;
        for (int width = 1; width < total; width *= 2) {
            final int step = width * 2;
            for (int lo = 0; lo < total - width; lo += step) {
                final int mid = lo + width - 1;
                // The last pair of runs may be cut short by the end of the array.
                final int hi = Math.min(lo + step - 1, total - 1);
                merge(data, scratch, lo, mid, hi);
            }
        }
        return data[total - 1];
    }

    /** Sorts {@code st.a} with the recursive {@link #sort}; returns the last element. */
    @GenerateMicroBenchmark
    public static int recursive(RecursiveVsBottomUpSort st) {
        final int last = N - 1;
        sort(st.a, st.aux, 0, last);
        return st.a[last];
    }

    /** Recursively sorts the inclusive range {@code a[lo..hi]}. */
    static void sort(int[] a, int[] aux, int lo, int hi) {
        if (lo != hi) {
            final int mid = lo + (hi - lo) / 2;
            sort(a, aux, lo, mid);
            sort(a, aux, mid + 1, hi);
            merge(a, aux, lo, mid, hi);
        }
    }

    /**
     * Merges the sorted ranges {@code a[lo..mid]} and {@code a[mid+1..hi]}.
     * The left range is copied to {@code aux} as is and the right range is
     * copied in reverse, so the two read cursors start at opposite ends and
     * move towards each other.
     */
    static void merge(int[] a, int[] aux, int lo, int mid, int hi) {
        System.arraycopy(a, lo, aux, lo, mid + 1 - lo);

        for (int dst = mid + 1, src = hi; dst <= hi; dst++, src--) {
            aux[dst] = a[src];
        }

        int left = lo;
        int right = hi;
        int out = lo;
        while (out <= hi) {
            a[out++] = (aux[right] < aux[left]) ? aux[right--] : aux[left++];
        }
    }
}
like image 152
leventov Avatar answered Sep 28 '22 15:09

leventov