Efficient way to stack Dask Arrays generated from Xarray

I am trying to read a large number of relatively large netCDF files containing hydrologic data. The netCDF files all look like this:

<xarray.Dataset>
Dimensions:         (feature_id: 2729077, reference_time: 1, time: 1)
Coordinates:
  * time            (time) datetime64[ns] 1993-01-11T21:00:00
  * reference_time  (reference_time) datetime64[ns] 1993-01-01
  * feature_id      (feature_id) int32 101 179 181 183 185 843 845 847 849 ...
Data variables:
    streamflow      (feature_id) float64 dask.array<shape=(2729077,), chunksize=(50000,)>
    q_lateral       (feature_id) float64 dask.array<shape=(2729077,), chunksize=(50000,)>
    velocity        (feature_id) float64 dask.array<shape=(2729077,), chunksize=(50000,)>
    qSfcLatRunoff   (feature_id) float64 dask.array<shape=(2729077,), chunksize=(50000,)>
    qBucket         (feature_id) float64 dask.array<shape=(2729077,), chunksize=(50000,)>
    qBtmVertRunoff  (feature_id) float64 dask.array<shape=(2729077,), chunksize=(50000,)>
Attributes:
    featureType:                timeSeries
    proj4:                      +proj=longlat +datum=NAD83 +no_defs
    model_initialization_time:  1993-01-01_00:00:00
    station_dimension:          feature_id
    model_output_valid_time:    1993-01-11_21:00:00
    stream_order_output:        1
    cdm_datatype:               Station
    esri_pe_string:             GEOGCS[GCS_North_American_1983,DATUM[D_North_...
    Conventions:                CF-1.6
    model_version:              NWM 1.2
    dev_OVRTSWCRT:              1
    dev_NOAH_TIMESTEP:          3600
    dev_channel_only:           0
    dev_channelBucket_only:     0
    dev:                        dev_ prefix indicates development/internal me...

I have 25 years' worth of this data, recorded hourly, so there is about 4 TB of data in total.
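As a rough check on why the stacked array has to stay lazy, the arithmetic below estimates the in-memory size of the streamflow variable alone. It assumes 365.25-day years, one file per hour, and 8-byte float64 values; this is the uncompressed size, so the on-disk figure can be smaller if the netCDF files are compressed.

```python
# Back-of-envelope size of the streamflow variable alone, if fully loaded.
# Assumptions: 365.25-day years, one file per hour, float64 (8-byte) values.
N_FEATURES = 2_729_077          # length of the feature_id dimension
BYTES_PER_VALUE = 8             # float64
YEARS = 25

hours = int(YEARS * 365.25 * 24)                 # number of hourly files
bytes_per_step = N_FEATURES * BYTES_PER_VALUE    # one file's streamflow array
total_bytes = hours * bytes_per_step             # all files stacked together

print(f"{hours:,} hourly steps")
print(f"{bytes_per_step / 1e6:.1f} MB per step")
print(f"{total_bytes / 1e12:.2f} TB for streamflow alone")
```

The result (roughly 4.8 TB for one variable) is far beyond any single machine's RAM, so the stacked array only works if it stays chunked and lazy until the averages are reduced.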

Right now I am just trying to get seasonal averages (daily and monthly) of the streamflow values, so I wrote the following script.

import xarray as xr
import dask.array as da
from dask.distributed import Client
import os

workdir = '/path/to/directory/of/files'
files = [os.path.join(workdir, i) for i in os.listdir(workdir)]

client = Client(processes=False, threads_per_worker=4, n_workers=4, memory_limit='750MB')

big_array = []

for i, file in enumerate(files):
    ds = xr.open_dataset(file, chunks={"feature_id": 50000})

    if i == 0:
        print(ds)

    print(ds.streamflow)

    big_array.append(ds.streamflow)

    ds.close()

    if i == 5:
        break

dask_big_array = da.stack(big_array, axis=0)

print(dask_big_array)

When printed, the ds.streamflow object looks like this; from what I understand, it is just a Dask array:

<xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000]

The weird thing is that when I stack the arrays, they seem to lose the chunking I applied earlier. When I print the stacked dask_big_array I get this:

dask.array<stack, shape=(6, 2729077), dtype=float64, chunksize=(1, 2729077)>
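The change in chunk size is easier to judge in bytes. This sketch assumes float64 values and that dask parses memory_limit='750MB' as 750 × 10^6 bytes; it compares the chunk requested in open_dataset, the chunk reported after stacking, and the worker limit passed to Client in the script.

```python
# Memory per chunk for float64 data, compared with the 750 MB worker limit
# passed to Client(...) in the script above.
BYTES_PER_VALUE = 8
WORKER_LIMIT = 750e6                              # memory_limit='750MB', taken as 10**6-byte MB

requested_chunk = 50_000 * BYTES_PER_VALUE        # chunks={"feature_id": 50000}
stacked_chunk = 2_729_077 * BYTES_PER_VALUE       # chunksize=(1, 2729077)

print(f"requested chunk: {requested_chunk / 1e3:.0f} kB")
print(f"stacked chunk:   {stacked_chunk / 1e6:.1f} MB")
print(f"whole-file chunks that fit in one worker limit: {int(WORKER_LIMIT // stacked_chunk)}")
```

Each stacked chunk is about 55 times larger than the chunk that was requested, so the stacked array no longer has the fine-grained blocks that were set up when opening the files.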

The problem is that when I run this code I get the following warning, and then memory appears to be exhausted, so I have to kill the process.

distributed.worker - WARNING - Memory use is high but worker has no data to store to disk...

So I guess I have a few questions:

  1. Why does the Dask array lose its chunking when stacked?
  2. Is there a more efficient way to stack all of these arrays so that the process runs in parallel?

As requested in the comments, this is what big_array contains:

[<xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000], <xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000], <xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000], <xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000], <xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000], <xarray.DataArray 'streamflow' (feature_id: 2729077)>
dask.array<shape=(2729077,), dtype=float64, chunksize=(50000,)>
Coordinates:
  * feature_id  (feature_id) int32 101 179 181 183 185 843 845 847 849 851 ...
Attributes:
    long_name:    River Flow
    units:        m3 s-1
    coordinates:  latitude longitude
    valid_range:  [       0 50000000]]
Asked by pythonweb, Mar 05 '23 10:03


1 Answer

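The problem here is that dask.array.stack() doesn't recognize xarray.DataArray objects as holding dask arrays, so it converts each of them to a NumPy array instead. That conversion loads every file's full streamflow array into memory, which is how you end up exhausting your memory. It is also why the chunking disappears: each converted array becomes a single chunk, which gives chunksize=(1, 2729077).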
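The distinction this explanation relies on can be checked with a small synthetic stand-in for ds.streamflow (the sizes and values below are made up for illustration). The xarray.DataArray is a wrapper, its .data attribute is the dask array, and stacking the unwrapped arrays, as in option 1 below, keeps the original chunks.

```python
import dask.array as da
import numpy as np
import xarray as xr

# Synthetic stand-in for ds.streamflow: a small DataArray backed by a
# chunked dask array (1,000 features instead of the real 2,729,077).
n_features = 1_000
arrays = [
    xr.DataArray(
        da.zeros(n_features, chunks=100),
        dims="feature_id",
        coords={"feature_id": np.arange(n_features)},
        name="streamflow",
    )
    for _ in range(6)
]

first = arrays[0]
print(type(first).__name__)        # DataArray: the xarray wrapper
print(type(first.data).__name__)   # Array: the dask array inside it

# Stacking the unwrapped dask arrays keeps the per-file chunking and
# adds a new leading axis with one chunk per input array.
stacked = da.stack([a.data for a in arrays], axis=0)
print(stacked.shape, stacked.chunksize)
```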

You could fix this in several ways:

  1. Call dask.array.stack() on a list of dask arrays, e.g., switch big_array.append(ds.streamflow) to big_array.append(ds.streamflow.data).
  2. Use xarray.concat() instead of dask.array.stack(), e.g., writing dask_big_array = xarray.concat(big_array, dim='time').
  3. Use xarray.open_mfdataset() which combines the process of opening many files and stacking them together, e.g., replacing all of your logic here with xarray.open_mfdataset('/path/to/directory/of/files/*').
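A minimal sketch of option 2 carried through to the averages the question asks for, using synthetic hourly DataArrays in place of the real files. The sizes, values, and the expand_dims step that attaches each file's timestamp are assumptions for illustration, and "seasonal" is read here as climatological means grouped by day of year and by month.

```python
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic hourly "files": 48 hourly steps, 1,000 features, chunked along
# feature_id like the question's chunks={...}. All sizes are illustrative.
n_features = 1_000
times = pd.date_range("1993-01-01", periods=48, freq="h")

per_file = [
    xr.DataArray(
        da.ones(n_features, chunks=100) * i,       # value i at hour i
        dims="feature_id",
        coords={"feature_id": np.arange(n_features)},
        name="streamflow",
    ).expand_dims(time=[t])                        # attach this step's timestamp
    for i, t in enumerate(times)
]

# Option 2: concatenate along time with xarray; the data stays lazy and chunked.
streamflow = xr.concat(per_file, dim="time")
print(streamflow.data.chunksize)

# Climatological means; nothing is computed until the values are requested.
daily = streamflow.groupby("time.dayofyear").mean()
monthly = streamflow.groupby("time.month").mean()
print(float(daily.isel(dayofyear=0, feature_id=0)))
print(float(monthly.isel(month=0, feature_id=0)))
```

The reductions only run when values are requested (here via float(...)), so the same pattern can scale as long as the chunks stay small; with the real files, option 3 would replace the hand-built per_file list.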
Answered by shoyer, Mar 21 '23 01:03