Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

multiprocessing with Xarray and Numpy array

So I am trying to implement a solution that was already described here, but I am changing it a bit. Instead of just trying to change the array with operations, I am trying to read from a NetCDF file using xarray and then write to a shared numpy array with the multiprocessing module.

I feel as though I am getting pretty close, but something is going wrong. I have pasted a reproducible, easy copy/paste example below. As you can see, when I run the processes, they can all read the files that I created, but they do not correctly update the shared numpy array that I am trying to write to. Any help would be appreciated.

Code

import ctypes
import logging
import multiprocessing as mp
import xarray as xr

from contextlib import closing

import numpy as np

info = mp.get_logger().info


def main():

    data = np.arange(10)

    for i in range(4):
        ds = xr.Dataset({'x': data})
        ds.to_netcdf('test_{}.nc'.format(i))

        ds.close()


    logger = mp.log_to_stderr()
    logger.setLevel(logging.INFO)

    # create shared array
    N, M = 4, 10
    shared_arr = mp.Array(ctypes.c_float, N * M)
    arr = tonumpyarray(shared_arr, dtype=np.float32)
    arr = arr.reshape((N, M))

    # Fill with random values
    arr[:, :] = np.zeros((N, M))
    arr_orig = arr.copy()

    files = ['test_0.nc', 'test_1.nc', 'test_2.nc', 'test_3.nc']

    parameter_tuples = [
        (files[0], 0),
        (files[1], 1),
        (files[2], 2),
        (files[3], 3)
    ]

    # write to arr from different processes
    with closing(mp.Pool(initializer=init, initargs=(shared_arr,))) as p:
        # many processes access different slices of the same array
        p.map_async(g, parameter_tuples)
    p.join()

    print(arr_orig)
    print(tonumpyarray(shared_arr, np.float32).reshape(N, M))


def init(shared_arr_):
    global shared_arr
    shared_arr = shared_arr_  # must be inherited, not passed as an argument


def tonumpyarray(mp_arr, dtype=np.float64):
    return np.frombuffer(mp_arr.get_obj(), dtype)


def g(params):
    """no synchronization."""
    print("Current File Name: ", params[0])

    tmp_dataset = xr.open_dataset(params[0])

    print(tmp_dataset["x"].data[:])

    arr = tonumpyarray(shared_arr)
    arr[params[1], :] = tmp_dataset["x"].data[:]

    tmp_dataset.close()


if __name__ == '__main__':
    mp.freeze_support()
    main()
like image 494
pythonweb Avatar asked May 13 '26 19:05

pythonweb


1 Answers

What's wrong?

1.You forgot to reshape back after tonumpyarray.
2.You used the wrong dtype in tonumpyarray.

Code

import ctypes
import logging
import multiprocessing as mp
import xarray as xr

from contextlib import closing

import numpy as np

info = mp.get_logger().info


def main():

    data = np.arange(10)

    for i in range(4):
        ds = xr.Dataset({'x': data})
        ds.to_netcdf('test_{}.nc'.format(i))

        ds.close()


    logger = mp.log_to_stderr()
    logger.setLevel(logging.INFO)

    # create shared array
    N, M = 4, 10
    shared_arr = mp.Array(ctypes.c_float, N * M)
    arr = tonumpyarray(shared_arr, dtype=np.float32)
    arr = arr.reshape((N, M))

    # Fill with random values
    arr[:, :] = np.zeros((N, M))
    arr_orig = arr.copy()

    files = ['test_0.nc', 'test_1.nc', 'test_2.nc', 'test_3.nc']

    parameter_tuples = [
        (files[0], 0),
        (files[1], 1),
        (files[2], 2),
        (files[3], 3)
    ]

    # write to arr from different processes
    with closing(mp.Pool(initializer=init, initargs=(shared_arr, N, M))) as p:
        # many processes access different slices of the same array
        p.map_async(g, parameter_tuples)
    p.join()

    print(arr_orig)
    print(tonumpyarray(shared_arr, np.float32).reshape(N, M))


def init(shared_arr_, N_, M_):    # add shape
    global shared_arr
    global N, M
    shared_arr = shared_arr_  # must be inherited, not passed as an argument
    N = N_
    M = M_


def tonumpyarray(mp_arr, dtype=np.float32):  # change type
    return np.frombuffer(mp_arr.get_obj(), dtype)


def g(params):
    """no synchronization."""
    print("Current File Name: ", params[0])

    tmp_dataset = xr.open_dataset(params[0])

    print(tmp_dataset["x"].data[:])

    arr = tonumpyarray(shared_arr).reshape(N, M)   # reshape
    arr[params[1], :] = tmp_dataset["x"].data[:]

    tmp_dataset.close()


if __name__ == '__main__':
    mp.freeze_support()
    main()
like image 57
Hao Li Avatar answered May 15 '26 09:05

Hao Li



Donate For Us

If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!