
How to downsample xarray dataset using groupby?

I'd like to downsample an xarray dataset by group: I use groupby to select each group and then keep 10% of the samples within it. With the code below I get IndexError: index 1330 is out of bounds for axis 0 with size 1330. That suggests to me that my function is returning an empty array, but subset definitely has nonzero dimensions.

I was using squeeze=True which I thought would allow for new dimensions as per the GroupBy documentation but that didn't help, so I changed it to squeeze=False.

Do you know what may be happening? Thank you!

# Set random seed for reproducibility
np.random.seed(0)

def select_random_cell_subset(x):
    size = int(0.1 * len(x.cell))
    random_cells = sorted(np.random.choice(x.cell, size=size, replace=False))
    print('number of random cells:', len(random_cells))
    print('\tsome random cells:', random_cells[:5])
    subset = x.sel(cell=random_cells)
    print('subset:', subset)
    return subset

# squeeze=False because the final dataset is smaller than the original
ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
ds_subset

Here is the error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-44-39c7803e9e40> in <module>()
     12 
     13 # squeeze=False because the final dataset is smaller than the original
---> 14 ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
     15 ds_subset

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in apply(self, func, **kwargs)
    615         kwargs.pop('shortcut', None)  # ignore shortcut if set (for now)
    616         applied = (func(ds, **kwargs) for ds in self._iter_grouped())
--> 617         return self._combine(applied)
    618 
    619     def _combine(self, applied):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _combine(self, applied)
    622         coord, dim, positions = self._infer_concat_args(applied_example)
    623         combined = concat(applied, dim)
--> 624         combined = _maybe_reorder(combined, dim, positions)
    625         if coord is not None:
    626             combined[coord.name] = coord

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _maybe_reorder(xarray_obj, dim, positions)
    443         return xarray_obj
    444     else:
--> 445         return xarray_obj[{dim: order}]
    446 
    447 

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in __getitem__(self, key)
    716         """
    717         if utils.is_dict_like(key):
--> 718             return self.isel(**key)
    719 
    720         if hashable(key):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in isel(self, drop, **indexers)
   1141         for name, var in iteritems(self._variables):
   1142             var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
-> 1143             new_var = var.isel(**var_indexers)
   1144             if not (drop and name in var_indexers):
   1145                 variables[name] = new_var

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in isel(self, **indexers)
    568             if dim in indexers:
    569                 key[i] = indexers[dim]
--> 570         return self[tuple(key)]
    571 
    572     def squeeze(self, dim=None):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in __getitem__(self, key)
    398         dims = tuple(dim for k, dim in zip(key, self.dims)
    399                      if not isinstance(k, integer_types))
--> 400         values = self._indexable_data[key]
    401         # orthogonal indexing should ensure the dimensionality is consistent
    402         if hasattr(values, 'ndim'):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/indexing.py in __getitem__(self, key)
    476     def __getitem__(self, key):
    477         key = self._convert_key(key)
--> 478         return self._ensure_ndarray(self.array[key])
    479 
    480     def __setitem__(self, key, value):

IndexError: index 1330 is out of bounds for axis 0 with size 1330
Olga Botvinnik asked Nov 20 '25 09:11


2 Answers

Here's how I implemented it. As @shoyer suggested in his answer, I return a boolean xarray.DataArray for each group and then use that boolean mask to subset my data.

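import numpy as np
import xarray as xr

# Set random seed for reproducibility
np.random.seed(0)

def select_random_cell_subset(x, threshold=0.1):
    # Keep each cell with probability `threshold` (about 10% per group, not exactly 10%)
    random_bools = xr.DataArray(np.random.uniform(size=len(x.cell)) <= threshold,
                                dims='cell', coords=dict(cell=x.cell))
    return random_bools

# Build one boolean mask over all cells, group by group
subset_bools = ds.groupby('group').apply(select_random_cell_subset,
                                         threshold=0.1)
# Keep only the cells whose mask value is True
ds_subset = ds.sel(cell=subset_bools)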
Olga Botvinnik answered Nov 22 '25 22:11


This is a totally sensible thing to do, but sadly it doesn't work yet. Xarray uses some heuristics to decide whether an apply operation is of the reduce or transform type, and in this case we incorrectly identify the grouped operation as a "transform" because outputs reuse the original dimension name. I just filed a bug report but unfortunately the fix for xarray will be somewhat involved.

Probably the easiest workaround would be to have the applied function return a boolean DataArray instead, indicating the positions to keep. Then you can use an indexing operation to select from the original object.
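
For illustration, here is a minimal, self-contained sketch of that workaround on a toy dataset. The dataset layout, the variable name expression, and the helper keep_fraction are made up for this example, and it assumes a recent xarray where groupby objects have .map (older releases spell it .apply). Unlike a per-cell random threshold, this version keeps exactly int(fraction * n) cells in each group.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)  # seeded generator for reproducibility

# Toy dataset (hypothetical): 200 cells in two groups of unequal size
n_cells = 200
ds = xr.Dataset(
    {"expression": (("cell", "gene"), rng.normal(size=(n_cells, 3)))},
    coords={
        "cell": np.arange(n_cells),
        "group": ("cell", np.repeat(["a", "b"], [150, 50])),
    },
)

def keep_fraction(group_labels, fraction=0.1):
    """Return a boolean DataArray marking which cells of one group to keep."""
    n = group_labels.sizes["cell"]
    keep = np.zeros(n, dtype=bool)
    # Pick exactly int(fraction * n) positions within this group
    keep[rng.choice(n, size=int(fraction * n), replace=False)] = True
    return xr.DataArray(keep, dims="cell", coords={"cell": group_labels["cell"]})

# The function returns an array along the original "cell" dimension,
# so the per-group results are stitched back into one mask over all cells
mask = ds["group"].groupby("group").map(keep_fraction)

# Select by label so the result does not depend on the order of the mask
kept_cells = mask["cell"].values[mask.values]
ds_subset = ds.sel(cell=kept_cells)
```

With these group sizes, ds_subset should contain 20 cells: 15 from group a and 5 from group b.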

shoyer answered Nov 22 '25 22:11