How can I parallelize sorting?

I would like to sort a large number of things.

Julia's standard library sort is single threaded. How can I take advantage of my multicore machine to sort things faster?

asked Nov 11 '17 07:11 by Lyndon White


2 Answers

Here is a solution using the (kind of experimental) Base.Threads threading module.

A solution using pmap (or similar) for distributed parallelism would look much the same, though I think the inter-process communication overhead would hurt you.
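For comparison, a minimal sketch of the distributed variant, using `pmap` from the Distributed standard library. The chunking scheme and variable names are illustrative assumptions, not part of this answer. `pmap` serializes each chunk to a worker and the sorted result back, which is the communication overhead mentioned above; the sorted chunks would still need to be merged afterwards.

```julia
using Distributed
# addprocs(4)   # uncomment to start 4 local worker processes; without workers pmap runs locally

data = rand(Int, 10_000)
nchunks = 4
n = length(data)

# Split into nchunks contiguous, nearly equal pieces (each slice is a copy)
chunks = [data[(div((i - 1) * n, nchunks) + 1):div(i * n, nchunks)] for i in 1:nchunks]

# Each chunk is shipped to a worker, sorted there, and the result shipped back
sorted_chunks = pmap(sort, chunks)
```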

The idea is to split the data into blocks (one per thread) and sort each block, so every thread works fully independently on its own block.
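The block-sorting step can also be written with task-based threading. This is a sketch of an alternative to the `@threads` loop in this answer, assuming Julia 1.3 or later for `Threads.@spawn`; `chunk_ranges` and `sort_blocks!` are hypothetical names.

```julia
# Hypothetical helper: split 1:n into k contiguous ranges whose lengths differ by at most one
chunk_ranges(n, k) = [(div((i - 1) * n, k) + 1):div(i * n, k) for i in 1:k]

# Sort each block of `data` in place on its own task; returns the (non-empty) block ranges
function sort_blocks!(data::AbstractVector, k = Threads.nthreads())
    ranges = filter(!isempty, chunk_ranges(length(data), k))  # empty when length(data) < k
    tasks = [Threads.@spawn sort!(view(data, r)) for r in ranges]
    foreach(wait, tasks)  # block until every chunk has been sorted
    return ranges
end

data = rand(1:100, 1000)
ranges = sort_blocks!(data, 4)
```

Afterwards each block is sorted on its own, but the vector as a whole is not yet; that is what the merge step is for.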

Then the pre-sorted blocks have to be merged.

This is the well-known problem of merging k sorted lists; see other questions on that topic.
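One standard way to merge the blocks is pairwise, like the levels of a merge tree. This is a different technique from the linear scan over all block heads used in this answer: the scan costs O(n·k) for k blocks, whereas pairwise merging costs O(n log k), and the merges within one level are independent of each other. A minimal sketch; `merge_sorted` and `merge_blocks` are hypothetical names, and `merge_blocks` assumes at least one block.

```julia
# Merge two sorted vectors into a new sorted vector in O(length(a) + length(b))
function merge_sorted(a::AbstractVector, b::AbstractVector)
    out = similar(a, promote_type(eltype(a), eltype(b)), length(a) + length(b))
    i, j = firstindex(a), firstindex(b)
    for k in eachindex(out)
        # Take from `a` if `b` is exhausted, or if a's head is not larger than b's head
        if j > lastindex(b) || (i <= lastindex(a) && a[i] <= b[j])
            out[k] = a[i]; i += 1
        else
            out[k] = b[j]; j += 1
        end
    end
    return out
end

# Repeatedly merge neighbouring blocks until a single sorted vector remains
function merge_blocks(blocks::Vector{<:AbstractVector})
    while length(blocks) > 1
        blocks = [k < length(blocks) ? merge_sorted(blocks[k], blocks[k+1]) : blocks[k]
                  for k in 1:2:length(blocks)]
    end
    return blocks[1]
end

merged = merge_blocks([[1, 4, 7], [2, 5], [3, 6, 8, 9]])
```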

And don't forget to enable multithreading by setting the environment variable JULIA_NUM_THREADS before you start Julia.
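For example, in a POSIX shell `JULIA_NUM_THREADS=8 julia` starts Julia with 8 threads (Julia 1.5 and later also accept `julia --threads 8`). A small sketch to confirm the setting from inside the session:

```julia
# Number of threads this Julia session was started with
n = Threads.nthreads()
println("Julia is running with $n thread(s)")
if n == 1
    @warn "Only one thread available; set JULIA_NUM_THREADS before starting Julia."
end
```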

Here is my code:

using Base.Threads

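# Split 1:total_len into nblocks contiguous ranges whose lengths differ by at most one.
# Returns the vector of start indices and the vector of end indices.
function blockranges(nblocks, total_len)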
    rem = total_len % nblocks
    main_len = div(total_len, nblocks)

    starts=Int[1]
    ends=Int[]
    for ii in 1:nblocks
        len = main_len
        if rem>0
            len+=1
            rem-=1
        end
        push!(ends, starts[end]+len-1)
        push!(starts, ends[end] + 1)
    end
    @assert ends[end] == total_len
    starts[1:end-1], ends
end

function threadedsort!(data::Vector)
    allstarts, allends = blockranges(nthreads(), length(data))
    # Drop empty blocks (there are some when length(data) < nthreads()),
    # otherwise the merge below would read an element outside the block
    nonempty = allstarts .<= allends
    starts, ends = allstarts[nonempty], allends[nonempty]

    # Sort each block, one block per thread
    @threads for b in eachindex(starts)
        @inbounds sort!(@view data[starts[b]:ends[b]])
    end

    # Go through each sorted block taking out the smallest item and putting it in the new array
    # This code could maybe be optimised. see https://stackoverflow.com/a/22057372/179081
    ret = similar(data) # main bit of allocation right here. avoiding it seems expensive.
    # Need to not overwrite data we haven't read yet
    @inbounds for ii in eachindex(ret)
        minblock_id = 1
        ret[ii]=data[starts[1]]
        # findmin allocates a lot for some reason, so do the find by hand. (maybe use findmin! ?)
        for blockid in 2:lastindex(starts)
            ele = data[starts[blockid]]
            if ret[ii] > ele
                ret[ii] = ele
                minblock_id = blockid
            end
        end
        starts[minblock_id]+=1 # move the start point forward
        if starts[minblock_id] > ends[minblock_id] # block exhausted: stop considering it
            deleteat!(starts, minblock_id)
            deleteat!(ends, minblock_id)
        end
    end
    data.=ret  # copy back into the original, as we said we would sort in place
    return data
end

I have done some benchmarking:

using Plots
function evaluate_timing(range)
    sizes = Int[]
    threadsort_times = Float64[]
    sort_times = Float64[]
    for sz in 2 .^ collect(range)
        data_orig = rand(Int, sz)
        push!(sizes, sz)

        data = copy(data_orig)
        push!(sort_times,       @elapsed sort!(data))

        data = copy(data_orig)
        push!(threadsort_times, @elapsed threadedsort!(data))

        @show (sz, sort_times[end], threadsort_times[end])
    end
    return sizes, threadsort_times, sort_times
end

sizes, threadsort_times, sort_times = evaluate_timing(0:28)
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"])
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"], xscale=:log10, yscale=:log10)

My results, using 8 threads:

[Plots: sorting time vs. number of elements, on a normal scale and on a log-log scale]

I found the crossover point to be surprisingly low: a bit over 1024 elements. Note that the long time taken by the first run can be ignored; that is the code being JIT-compiled.

Weirdly, these results do not reproduce when using BenchmarkTools, which would have kept that initial compilation time from being counted. But they do reproduce very consistently with plain timing code like the benchmark code above. I guess BenchmarkTools is doing something that kills the multithreading somehow.
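When timing with plain code, two things are worth controlling: compilation (a warm-up call keeps it out of the measurement) and the fact that `sort!` mutates its input, so timing it repeatedly on the same array measures sorting already-sorted data. BenchmarkTools offers `setup=` and `evals=1` for the latter; whether either effect explains the discrepancy above is untested. A minimal sketch with plain `@elapsed`; `time_inplace` is a hypothetical helper:

```julia
# Hypothetical helper: time one call of the in-place sorter `f!` on a fresh copy of `data`
function time_inplace(f!, data::AbstractVector)
    f!(data[1:min(end, 100)])  # warm-up on a small slice (a copy), so JIT compilation is not timed
    fresh = copy(data)         # f! mutates its argument, so always time an unsorted copy
    return @elapsed f!(fresh)
end

data = rand(Int, 10^6)
t = time_inplace(sort!, data)
println("sort! took $t seconds")
```

The same helper can be called with `threadedsort!` in place of `sort!`.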

Big thanks to @xiaodai, who pointed out a mistake in my analysis code.

answered Oct 30 '22 13:10 by Lyndon White


I have further tested the case where only 1% of the items are unique, and also sampling from 1:1_000_000. Results are below.

[Images: result plots for both tests]

function evaluate_timing_w_repeats(range)
    sizes = Int[]
    threadsort_times = Float64[]
    sort_times = Float64[]
    for sz in 2 .^ collect(range)
        data_orig = rand(rand(Int, sz÷100), sz)
        push!(sizes, sz)

        data = copy(data_orig)
        push!(sort_times,       @elapsed sort!(data))

        data = copy(data_orig)
        push!(threadsort_times, @elapsed threadedsort!(data))

        @show (sz, sort_times[end], threadsort_times[end])
    end
    return sizes, threadsort_times, sort_times
end

sizes, threadsort_times, sort_times = evaluate_timing_w_repeats(7:28)
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"])
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"], xscale=:log10, yscale=:log10)
savefig("sort_with_repeats.png")
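The expression `rand(rand(Int, sz÷100), sz)` draws `sz` samples, with replacement, from a pool of `sz÷100` random integers, so at most 1% of the values are distinct. A small sketch to check that property for one size (the size is chosen arbitrarily):

```julia
sz = 2^16
pool = rand(Int, sz ÷ 100)   # 655 candidate values
data = rand(pool, sz)        # sample sz values with replacement from the pool

frac_unique = length(unique(data)) / sz
println("fraction of distinct values: $frac_unique")
```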

function evaluate_timing1m(range)
    sizes = Int[]
    threadsort_times = Float64[]
    sort_times = Float64[]
    for sz in 2 .^ collect(range)
        data_orig = rand(1:1_000_000, sz)
        push!(sizes, sz)

        data = copy(data_orig)
        push!(sort_times,       @elapsed sort!(data))

        data = copy(data_orig)
        push!(threadsort_times, @elapsed threadedsort!(data))

        @show (sz, sort_times[end], threadsort_times[end])
    end
    return sizes, threadsort_times, sort_times
end

sizes, threadsort_times, sort_times = evaluate_timing1m(7:28)
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"])
plot(sizes, [threadsort_times sort_times]; title="Sorting Time sample from 1:1_000_000", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"], xscale=:log10, yscale=:log10)
savefig("sort1m.png")

answered Oct 30 '22 13:10 by xiaodai