 

Parallelized bootstrapping with replacement with xarray/dask

I want to perform N=1000 bootstrap iterations with replacement on gridded data. One computation takes about 0.5 s. I have exclusive access to a 48-core node on a supercomputer. Because the resamplings are independent of each other, I naively hoped to distribute the workload across all (or at least many) cores and get a speedup of about 0.8 * ncores. But I don't get it.
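As a reference point, here is a minimal NumPy-only sketch of what one bootstrap run does: draw the init indices with replacement, compute the statistic, and repeat. The shapes mirror the MWE below. The helper name `bootstrap_std` and the use of `np.random.default_rng` are illustrative assumptions, not part of the original code.

```python
import numpy as np

def bootstrap_std(a, b, n_boot, seed=0):
    """Hypothetical helper: std over axis 0 of (resampled a - b), n_boot times."""
    rng = np.random.default_rng(seed)
    n_init = a.shape[0]
    out = np.empty((n_boot,) + a.shape[1:])
    for i in range(n_boot):
        # draw n_init positions along the init axis, with replacement
        idx = rng.integers(0, n_init, size=n_init)
        out[i] = (a[idx] - b).std(axis=0)
    return out

rng = np.random.default_rng(42)
a = rng.random((50, 96, 192))
b = rng.random((50, 96, 192))
res = bootstrap_std(a, b, n_boot=10)
print(res.shape)  # (10, 96, 192)
```

Each iteration only depends on its own random draw, which is why the problem looks embarrassingly parallel in principle.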

I still lack a proper understanding of dask. Based on Best practices in setting number of dask workers, I use:

from dask.distributed import Client
client = Client(processes=False, threads_per_worker=8, n_workers=6, memory_limit='32GB')

I also tried SLURMCluster, but I think I first need to understand what I'm doing before scaling up.

My MWE:

  1. create sample data
  2. write the function I want to apply
  3. write a function that resamples the inits
  4. write a bootstrapping function that takes bootstrap (=N) as an argument (several implementations below)
  5. perform bootstrapping
import dask
import numpy as np
import xarray as xr
from dask.distributed import Client

inits = np.arange(50)
lats = np.arange(96)
lons = np.arange(192)
data = np.random.rand(len(inits), len(lats), len(lons))
a = xr.DataArray(data,
                 coords=[inits, lats, lons],
                 dims=['init', 'lat', 'lon'])

data = np.random.rand(len(inits), len(lats), len(lons))
b = xr.DataArray(data,
                 coords=[inits, lats, lons],
                 dims=['init', 'lat', 'lon'])

def func(a,b, dim='init'):
    return (a-b).std(dim)

bootstrap=96

def resample(a):
    smp_init = np.random.choice(inits, len(inits))
    smp_a = a.sel(init=smp_init)
    smp_a['init'] = inits
    return smp_a


# serial function
def bootstrap_func(bootstrap=bootstrap):
    res = (func(resample(a),b) for _ in range(bootstrap))
    res = xr.concat(res,'bootstrap')
    # leave out quantile because not issue here yet
    #res_ci = res.quantile([.05,.95],'bootstrap')
    return res


@dask.delayed
def bootstrap_func_delayed_decorator(bootstrap=bootstrap):
    return bootstrap_func(bootstrap=bootstrap)


def bootstrap_func_delayed(bootstrap=bootstrap):
    res = (dask.delayed(func)(resample(a),b) for _ in range(bootstrap))
    res = xr.concat(dask.compute(*res),'bootstrap')
    #res_ci = res.quantile([.05,.95],'bootstrap')
    return res

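# note: 'multiprocessing'/'processes' and 'synchronous'/'single-threaded' are aliases;
# 'distributed' needs an active Client, e.g. the one created above
for scheduler in ['synchronous','distributed','multiprocessing','processes','single-threaded','threads']:
    print('scheduler:',scheduler)

    # unlike the functions above, this one also computes the quantile;
    # resample(a) runs eagerly here, only func is delayed
    def bootstrap_func_delayed_processes(bootstrap=bootstrap):
        res = (dask.delayed(func)(resample(a),b) for _ in range(bootstrap))
        res = xr.concat(dask.compute(*res, scheduler=scheduler),'bootstrap')
        res = res.quantile([.05,.95],'bootstrap')
        return res

    %time c = bootstrap_func_delayed_processes()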

The following results are from my 4-core laptop. On the supercomputer I also see no speedup; instead, performance drops by about 50%.

Results for serial:

%time c = bootstrap_func()
CPU times: user 814 ms, sys: 58.7 ms, total: 872 ms
Wall time: 862 ms

Results for parallel:

%time c = bootstrap_func_delayed_decorator().compute()
CPU times: user 96.2 ms, sys: 50 ms, total: 146 ms
Wall time: 906 ms

Results from the scheduler loop:

scheduler: synchronous
CPU times: user 2.57 s, sys: 330 ms, total: 2.9 s
Wall time: 2.95 s
scheduler: distributed
CPU times: user 4.51 s, sys: 2.74 s, total: 7.25 s
Wall time: 8.86 s
scheduler: multiprocessing
CPU times: user 4.18 s, sys: 2.53 s, total: 6.71 s
Wall time: 7.95 s
scheduler: processes
CPU times: user 3.97 s, sys: 2.1 s, total: 6.07 s
Wall time: 7.39 s
scheduler: single-threaded
CPU times: user 2.26 s, sys: 275 ms, total: 2.54 s
Wall time: 2.47 s
scheduler: threads
CPU times: user 2.84 s, sys: 341 ms, total: 3.18 s
Wall time: 2.66 s

Expected result: a speedup of about 0.8 * ncores.
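Whether 0.8 * ncores is realistic can be checked with Amdahl's law, a standard model that is not part of the original post: if a fraction p of the runtime parallelizes perfectly over n cores, the speedup is S(n) = 1 / ((1 - p) + p / n). This sketch solves for the p needed to reach S = 0.8 n; the function names are illustrative.

```python
def amdahl_speedup(p: float, n: int) -> float:
    """Speedup predicted by Amdahl's law for parallel fraction p on n cores."""
    return 1.0 / ((1.0 - p) + p / n)

def required_parallel_fraction(target_speedup: float, n: int) -> float:
    """Parallel fraction p such that amdahl_speedup(p, n) == target_speedup."""
    # Solve 1/S = (1 - p) + p/n for p:  p = (1 - 1/S) / (1 - 1/n)
    return (1.0 - 1.0 / target_speedup) / (1.0 - 1.0 / n)

for n in (4, 48):
    p = required_parallel_fraction(0.8 * n, n)
    print(n, round(p, 4))
# 4 0.9167
# 48 0.9947
```

So on 48 cores, more than 99% of the runtime (including scheduling and data movement) would have to run in parallel to reach the expected speedup.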

Other considerations: I also checked whether I should chunk my data, but the chunks are too small; chunked arrays take longer.

My questions:

  • What did I get wrong about dask parallelization?
  • Is the client setup not useful this way?
  • Did I use dask.delayed in a way that is not clever enough?
  • Is my serial function already executed in parallel because of dask? I think not.
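One common remedy for many small tasks, swapped in here for illustration and not taken from this thread, is to batch the bootstrap iterations so that each dask.delayed task does substantial work (for example, roughly one batch per core). Note that bootstrap_func_delayed_decorator wraps the whole loop in a single task, so it cannot run in parallel at all. The sketch uses plain NumPy arrays; `run_batch` and `bootstrap_batched` are hypothetical helpers. With the threaded scheduler, any speedup depends on how much of the NumPy work releases the GIL.

```python
import dask
import numpy as np

def run_batch(a, b, n_iter, seed):
    """Hypothetical helper: run n_iter bootstrap iterations inside one task."""
    rng = np.random.default_rng(seed)
    n_init = a.shape[0]
    out = np.empty((n_iter,) + a.shape[1:])
    for i in range(n_iter):
        idx = rng.integers(0, n_init, size=n_init)  # resample with replacement
        out[i] = (a[idx] - b).std(axis=0)
    return out

def bootstrap_batched(a, b, n_boot, n_batches, scheduler="threads", seed=0):
    # split n_boot iterations as evenly as possible over n_batches tasks
    sizes = [len(s) for s in np.array_split(np.arange(n_boot), n_batches)]
    # spawn one independent random stream per batch from a single seed
    seeds = np.random.SeedSequence(seed).spawn(len(sizes))
    tasks = [dask.delayed(run_batch)(a, b, size, ss)
             for size, ss in zip(sizes, seeds) if size > 0]
    batches = dask.compute(*tasks, scheduler=scheduler)
    return np.concatenate(batches, axis=0)

rng = np.random.default_rng(0)
a = rng.random((50, 96, 192))
b = rng.random((50, 96, 192))
res = bootstrap_batched(a, b, n_boot=96, n_batches=4)
print(res.shape)  # (96, 96, 192)
```

Because the seeds are derived deterministically, the result does not depend on which scheduler runs the batches.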

asked May 23 '19 18:05 by aaron.spring



1 Answer

I finally solved this. When I posted this question, I clearly didn't understand a few aspects of it:

  • I ran the timings on a laptop with two physical cores, which doesn't allow much parallelization for a CPU-bound problem. I have now run it on a node with 48 logical CPUs.
  • I should have thought about which parts of the algorithm are easily parallelizable and which are not. Only then can I chunk accordingly.
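Before timing anything, it can help to check how many CPUs the Python process can actually use. This standard-library sketch is an addition for illustration: os.cpu_count() reports logical CPUs (hyperthreads included), not physical cores, and os.sched_getaffinity is only available on some platforms (for example Linux).

```python
import os

# Logical CPUs in the machine; may return None, so fall back to 1
logical = os.cpu_count() or 1

# CPUs this process may actually run on. On Linux this reflects affinity
# masks, such as those a batch scheduler may apply to a job.
if hasattr(os, "sched_getaffinity"):
    usable = len(os.sched_getaffinity(0))
else:
    usable = logical

print(f"logical CPUs: {logical}, usable by this process: {usable}")
```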

See my solution here: https://gist.github.com/aaronspring/118abd7b9bf81e555b1fced42eef427f

The game-changers compared with the code I posted initially:

  • I chunk a dimension (here x) that is not involved in func (which operates along time). In the MWE above, that would be lat or lon, since func reduces over init.
  • I still set up the client as described above, following Best practices in setting number of dask workers.
  • I only parallelize the iteration part; the quantile is computed in memory.
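A minimal sketch of that idea applied to the question's MWE, assuming xarray and dask are installed: chunk along lat (not init), build all bootstrap iterations as one lazy graph, compute, and only then take the quantile in memory. This is not the code from the linked gist (which uses x and time); the chunk size of 24, the rng argument, and the name `bootstrap_lazy` are illustrative assumptions.

```python
import numpy as np
import xarray as xr

# Sample data shaped like the MWE in the question
inits, lats, lons = np.arange(50), np.arange(96), np.arange(192)
rng = np.random.default_rng(0)
dims = ["init", "lat", "lon"]
a = xr.DataArray(rng.random((50, 96, 192)), coords=[inits, lats, lons], dims=dims)
b = xr.DataArray(rng.random((50, 96, 192)), coords=[inits, lats, lons], dims=dims)

# Chunk along lat only: func reduces over init, so each lat block is
# independent and init stays in a single chunk (chunk size is arbitrary)
a = a.chunk({"lat": 24})
b = b.chunk({"lat": 24})

def func(a, b, dim="init"):
    return (a - b).std(dim)

def resample(x, rng):
    n = x.sizes["init"]
    idx = rng.integers(0, n, size=n)  # positions drawn with replacement
    # relabel init so the resampled array aligns with b again
    return x.isel(init=idx).assign_coords(init=x["init"].values)

def bootstrap_lazy(a, b, n_boot, seed=0):
    rng = np.random.default_rng(seed)
    # builds one lazy dask graph; nothing is computed yet
    return xr.concat([func(resample(a, rng), b) for _ in range(n_boot)], dim="bootstrap")

# compute() runs the chunks on dask's default threaded scheduler
# (or on a distributed Client, if one is active)
res = bootstrap_lazy(a, b, n_boot=20).compute()
# the quantile is then computed in memory on the NumPy-backed result
ci = res.quantile([0.05, 0.95], dim="bootstrap")
print(dict(ci.sizes))
```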

Conclusion: It is simpler than expected. The gist shows implementations with dask.delayed and dask.futures, but that's not even needed in my use case. First try to understand parallelism (https://realpython.com/python-concurrency/) and read the dask documentation (https://dask.org/).

answered Oct 26 '22 22:10 by aaron.spring
