I am struggling to initialise thread-local ndarrays with cython.parallel:
Pseudo-code:

cdef:
    ndarray buffer

with nogil, parallel():
    buffer = np.empty(...)
    for i in prange(n):
        with gil:
            print "Thread %d: data address: 0x%x" % (threadid(), <uintptr_t>buffer.data)
        some_func(buffer.data) # use thread-local buffer

cdef void some_func(char * buffer_ptr) nogil:
    (... works on buffer contents...)

My problem is that buffer.data points to the same address in all threads, namely the address set by whichever thread assigned buffer last.
Despite buffer being assigned within the parallel() (or alternatively prange) block, Cython does not make buffer a private or thread-local variable but keeps it as a shared variable.
As a result, buffer.data points to the same memory region in every thread, wreaking havoc on my algorithm.
This problem does not seem to be exclusive to ndarray objects; it appears to affect all objects defined with cdef class.
How do I solve this problem?
I think I have finally found a solution to this problem that I like. The short version is that you create an array that has shape:
(number_of_threads, ...<whatever shape you need in the thread>...)
Then, call openmp.omp_get_thread_num and use that to index the array to get a "thread-local" sub-array. This avoids having a separate array for every loop index (which could be enormous) but also prevents threads overwriting each other.
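As a quick sanity check of the layout in plain NumPy (not the Cython code from this answer), the sketch below shows that each row of a (number_of_threads, ...) array is a view starting at its own address inside one shared allocation. The names scratch, num_threads and num_items and their values are assumptions chosen for illustration.

```python
import numpy as np

num_threads = 4   # assumed thread count, for illustration only
num_items = 8     # assumed length of each per-thread buffer

# One row per thread: shape (number_of_threads, <per-thread shape>)
scratch = np.zeros((num_threads, num_items), dtype=np.float64)

# Start address of each row view; every row lives in the same allocation
row_addresses = [scratch[t].ctypes.data for t in range(num_threads)]

# Size of one row in bytes; consecutive rows of a C-contiguous array
# are exactly this far apart
row_nbytes = num_items * scratch.itemsize

for t, address in enumerate(row_addresses):
    print("row %d starts at 0x%x" % (t, address))
```

Because the rows never overlap, a thread that only touches its own row cannot overwrite another thread's data.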
Here's a rough version of what I did:
import numpy as np
import multiprocessing
from cython.parallel cimport parallel
from cython.parallel import prange
cimport openmp

cdef extern from "stdlib.h":
    void free(void* ptr)
    void* malloc(size_t size)
    void* realloc(void* ptr, size_t size)

...

cdef int num_items = ...
num_threads = multiprocessing.cpu_count()
result_array = np.zeros((num_threads, num_items), dtype=DTYPE) # Make sure each thread uses separate memory

cdef c_numpy.ndarray result_cn
cdef CDTYPE ** result_pointer_arr
result_pointer_arr = <CDTYPE **> malloc(num_threads * sizeof(CDTYPE *))
# Store a raw pointer to each thread's row so it can be used without the GIL
for i in range(num_threads):
    result_cn = result_array[i]
    result_pointer_arr[i] = <CDTYPE*> result_cn.data

cdef int thread_number
for i in prange(num_items, nogil=True, chunksize=1, num_threads=num_threads, schedule='static'):
    # Each thread looks up its own row, so no two threads share a buffer
    thread_number = openmp.omp_get_thread_num()
    some_function(result_pointer_arr[thread_number])

free(result_pointer_arr) # release the pointer table only; result_array still owns the data
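
The same per-thread-row pattern can be tried out in plain Python with the threading module, which may help when checking the indexing logic before compiling any Cython. This is only a rough analogue under stated assumptions: the worker index t stands in for openmp.omp_get_thread_num(), the strided loop imitates a static schedule with chunksize 1, some_function here is a hypothetical stand-in that takes an extra index argument, and summing the rows at the end is just one possible way to combine the per-thread results.

```python
import threading
import numpy as np

def some_function(row, i):
    # Hypothetical stand-in for the real work: accumulate item i
    # into the calling thread's own row, in place.
    row[i % row.shape[0]] += i

def run(num_items=10, num_threads=4, width=3):
    # One scratch row per thread, so threads never share a buffer
    scratch = np.zeros((num_threads, width), dtype=np.int64)

    def worker(t):
        # Imitates schedule='static', chunksize=1:
        # thread t handles items t, t + num_threads, t + 2*num_threads, ...
        for i in range(t, num_items, num_threads):
            some_function(scratch[t], i)

    threads = [threading.Thread(target=worker, args=(t,))
               for t in range(num_threads)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return scratch

per_thread = run()
# Assumed reduction step: combine the per-thread rows into one result
total = per_thread.sum(axis=0)
print(per_thread)
print(total)
```

Each worker writes only to scratch[t], so no locking is needed while the loop runs; the rows are combined once all threads have finished.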