xarray too slow for performance critical code

I planned to use xarray extensively in some numerically intensive scientific code that I am writing. So far, it makes the code very elegant, but I think I will have to abandon it as the performance cost is far too high.

Here is an example that creates two arrays and multiplies parts of them together, first with xarray (using several indexing schemes) and then with numpy. I used num_comp=2 and num_x=10000:

Line #      Hits     Time   Per Hit   % Time  Line Contents
 4                                           @profile
 5                                           def xr_timing(num_comp, num_x):
 6         1         4112   4112.0     10.1      da1 = xr.DataArray(np.random.random([num_comp, num_x]).astype(np.float32), dims=['component', 'x'], coords={'component': ['a', 'b'], 'x': np.linspace(0, 1, num_x)})
 7         1          438    438.0      1.1      da2 = da1.copy()
 8         1         1398   1398.0      3.4      da2[:] = np.random.random([num_comp, num_x]).astype(np.float32)
 9         1         7148   7148.0     17.6      da3 = da1.isel(component=0).drop('component') * da2.isel(component=0).drop('component')
10         1         6298   6298.0     15.5      da4 = da1[dict(component=0)].drop('component') * da2[dict(component=0)].drop('component')
11         1         7541   7541.0     18.6      da5 = da1.sel(component='a').drop('component') * da2.sel(component='a').drop('component')
12         1         7184   7184.0     17.7      da6 = da1.loc[dict(component='a')].drop('component') * da2.loc[dict(component='a')].drop('component')
13         1         6479   6479.0     16.0      da7 = da1[0, :].drop('component') * da2[0, :].drop('component')

15                                           @profile
16                                           def np_timing(num_comp, num_x):
17         1         1027   1027.0     50.2      da1 = np.random.random([num_comp, num_x]).astype(np.float32)
18         1          977    977.0     47.8      da2 = np.random.random([num_comp, num_x]).astype(np.float32)
19         1           41     41.0      2.0      da3 = da1[0, :] * da2[0, :]
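
For reference, the two profiled functions can be reconstructed from the output above as the following standalone script. This is a sketch that assumes the @profile decorator is injected by line_profiler (run it via kernprof -l -v script.py; under plain python the decorator is undefined):

import numpy as np
import xarray as xr

@profile  # injected by line_profiler's kernprof
def xr_timing(num_comp, num_x):
    da1 = xr.DataArray(np.random.random([num_comp, num_x]).astype(np.float32),
                       dims=['component', 'x'],
                       coords={'component': ['a', 'b'],
                               'x': np.linspace(0, 1, num_x)})
    da2 = da1.copy()
    da2[:] = np.random.random([num_comp, num_x]).astype(np.float32)
    # five equivalent ways to select one component, then multiply
    da3 = da1.isel(component=0).drop('component') * da2.isel(component=0).drop('component')
    da4 = da1[dict(component=0)].drop('component') * da2[dict(component=0)].drop('component')
    da5 = da1.sel(component='a').drop('component') * da2.sel(component='a').drop('component')
    da6 = da1.loc[dict(component='a')].drop('component') * da2.loc[dict(component='a')].drop('component')
    da7 = da1[0, :].drop('component') * da2[0, :].drop('component')

@profile
def np_timing(num_comp, num_x):
    da1 = np.random.random([num_comp, num_x]).astype(np.float32)
    da2 = np.random.random([num_comp, num_x]).astype(np.float32)
    da3 = da1[0, :] * da2[0, :]

xr_timing(2, 10000)
np_timing(2, 10000)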

The fastest xarray multiplication takes about 150X as long as the numpy version. This is just one of the operations in my code, but I find that most of them are many times slower than their numpy equivalents, which is unfortunate, since xarray makes the code so much clearer. Am I doing something wrong?

Update: Even da1[0, :].values * da2[0, :].values (which loses many of the benefits of using xarray) takes 2464 time units.
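
In case it helps others, here is a minimal sketch of the workaround this timing suggests (variable names reuse the example above): extract the underlying NumPy arrays once, before the performance-critical section, so that the indexing itself also bypasses xarray:

# drop to NumPy once, up front, rather than once per operation
a = da1.values  # the plain np.ndarray backing da1
b = da2.values
da3 = a[0, :] * b[0, :]  # pure NumPy indexing and arithmetic from here on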

I am using xarray 0.9.6, pandas 0.21.0, numpy 1.13.3, and Python 3.5.2.

Update 2: As requested by @Maximilian, here is a re-run with num_x=1000000:

Line #      Hits   Time    Per Hit   % Time  Line Contents
# xarray
 9         5       408596  81719.2     11.3      da3 = da1.isel(component=0).drop('component') * da2.isel(component=0).drop('component')
10         5       407003  81400.6     11.3      da4 = da1[dict(component=0)].drop('component') * da2[dict(component=0)].drop('component')
11         5       411248  82249.6     11.4      da5 = da1.sel(component='a').drop('component') * da2.sel(component='a').drop('component')
12         5       411730  82346.0     11.4      da6 = da1.loc[dict(component='a')].drop('component') * da2.loc[dict(component='a')].drop('component')
13         5       406757  81351.4     11.3      da7 = da1[0, :].drop('component') * da2[0, :].drop('component')
14         5        48800   9760.0      1.4      da8 = da1[0, :].values * da2[0, :].values

# numpy
20         5        37476   7495.2      2.9      da3 = da1[0, :] * da2[0, :]

The performance difference has decreased substantially, as expected (the xarray version is only about 10X slower now), but I am still glad the issue will be mentioned in the next release of the documentation, since even this amount of difference may surprise some people.

asked Nov 08 '17 by user3708067

1 Answer

Yes, this is a known limitation of xarray. Performance-sensitive code that uses small arrays is much slower with xarray than with NumPy. I wrote a new section about this in our docs for the next version: http://xarray.pydata.org/en/stable/computation.html#wrapping-custom-computation

You basically have two options:

  1. Write your performance-sensitive code on unwrapped arrays, and then wrap them back in xarray data structures. Xarray v0.10 has a new helper function, apply_ufunc, that makes this a little easier (see the link above, and the sketch at the end of this answer).
  2. Use something other than xarray/Python to do your computation. This could also make sense because Python itself adds significant overhead. Julia's AxisArrays.jl looks interesting, though I haven't tried it myself.

I suppose option 3 would be to rewrite xarray itself in C++ (e.g., on top of xtensor), but that would be much more involved!
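
To illustrate option 1, here is a minimal sketch of apply_ufunc wrapping an element-wise NumPy function (the function name fast_multiply is made up for the example):

import numpy as np
import xarray as xr

def fast_multiply(a, b):
    # receives plain NumPy arrays, so the hot code pays no xarray overhead
    return a * b

da1 = xr.DataArray(np.random.random([2, 10000]).astype(np.float32),
                   dims=['component', 'x'],
                   coords={'component': ['a', 'b'], 'x': np.linspace(0, 1, 10000)})
da2 = da1.copy()

# apply_ufunc unwraps the DataArrays, calls fast_multiply once on the
# underlying arrays, and wraps the result back up with dims and coords
result = xr.apply_ufunc(fast_multiply, da1.sel(component='a'), da2.sel(component='a'))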

answered by shoyer