I need to count the number of zero elements in numpy arrays. I'm aware of the numpy.count_nonzero function, but there appears to be no analog for counting zero elements.
My arrays are not very large (typically fewer than 1E5 elements), but the operation is performed several million times.
Of course I could use len(arr) - np.count_nonzero(arr), but I wonder if there's a more efficient way to do it.
Here's a MWE of how I do it currently:
    import numpy as np
    import timeit

    arrs = []
    for _ in range(1000):
        arrs.append(np.random.randint(-5, 5, 10000))

    def func1():
        for arr in arrs:
            zero_els = len(arr) - np.count_nonzero(arr)

    print(timeit.timeit(func1, number=10))
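One caveat with the len(arr)-based expression: len() gives the size of the first axis only, so it matches the element count only for 1-D arrays. A small sketch, using an illustrative 2-D array and arr.size as the total instead:

```python
import numpy as np

arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

# len(arr) is the size of the first axis (2 here), not the element count (8).
wrong = len(arr) - np.count_nonzero(arr)   # 2 - 6 = -4
right = arr.size - np.count_nonzero(arr)   # 8 - 6 = 2

print(wrong, right)  # -4 2
```

For the 1-D arrays in the MWE above, both expressions give the same result.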
To count all the zeros in an array, pass the boolean mask arr == 0 to np.count_nonzero(). It returns the number of elements for which the condition holds (in this case, the elements equal to zero).
The count_nonzero() function counts the number of non-zero values in the array arr. Parameters:
arr : [array_like] The array for which to count non-zeros.
axis : [int or tuple, optional] Axis or tuple of axes along which to count non-zeros.
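As a minimal sketch of the axis parameter, the snippet below stacks a few equal-length arrays into one 2-D array and counts the zeros of every row in a single call. The names stacked, zeros_per_row and total_zeros are illustrative, and the approach assumes all arrays share the same length.

```python
import numpy as np

# Three equal-length arrays stacked into a 2-D array (one array per row).
stacked = np.array([
    [0, 1, 0, 3],
    [4, 5, 6, 7],
    [0, 0, 0, 8],
])

# axis=1 counts along each row; the default axis=None counts over the whole array.
zeros_per_row = np.count_nonzero(stacked == 0, axis=1)
total_zeros = np.count_nonzero(stacked == 0)

print(zeros_per_row)  # [2 0 3]
print(total_zeros)    # 5
```

When many small arrays of equal length are processed, one call like this may save the Python-level loop overhead, although that depends on the data layout.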
Python List Count Zero / Non-Zero. To count the number of zeros in a plain Python list, call the list.count(0) method.
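For a plain list, that might look like the following sketch (the list values is made up for illustration):

```python
values = [0, 3, 0, -1, 5, 0]

zero_count = values.count(0)              # number of elements equal to 0
nonzero_count = len(values) - zero_count  # everything else

print(zero_count, nonzero_count)  # 3 3
```

Note that list.count(0) compares by equality, so 0.0 and False are counted as zeros too.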
A faster approach (roughly 1.6x in the timings below) is to use np.count_nonzero() directly, with the condition you need:
    In [3]: arr
    Out[3]:
    array([[1, 2, 0, 3],
           [3, 9, 0, 4]])

    In [4]: np.count_nonzero(arr==0)
    Out[4]: 2

    In [5]: def func_cnt():
       ...:     for arr in arrs:
       ...:         zero_els = np.count_nonzero(arr==0)  # here, it actually counts the frequency of zeroes
You can also use np.where(), but it's slower than np.count_nonzero():
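    In [6]: np.where(arr == 0)
    Out[6]: (array([0, 1]), array([2, 2]))

    # np.where returns one index array per dimension, so take the length
    # of the first index array rather than the length of the tuple
    In [7]: len(np.where(arr == 0)[0])
    Out[7]: 2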
Efficiency (fastest first):
    In [8]: %timeit func_cnt()
    10 loops, best of 3: 29.2 ms per loop

    In [9]: %timeit func1()
    10 loops, best of 3: 46.5 ms per loop

    In [10]: %timeit func_where()
    10 loops, best of 3: 61.2 ms per loop
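The timings reference func_where, which is not defined anywhere in this answer. A plausible reconstruction, by analogy with func_cnt, is sketched below together with a check that the three approaches agree. The helper names (count_len_minus_nonzero, count_mask, count_where) and the body of func_where are assumptions, and the setup uses fewer arrays than the question to keep it quick.

```python
import numpy as np

rng = np.random.default_rng(0)
# 100 arrays of 10000 integers in [-5, 5), mirroring the question's setup.
arrs = [rng.integers(-5, 5, 10_000) for _ in range(100)]

def count_len_minus_nonzero(arr):
    # The question's approach (valid for 1-D arrays).
    return len(arr) - np.count_nonzero(arr)

def count_mask(arr):
    # Count the True entries of the boolean mask arr == 0.
    return np.count_nonzero(arr == 0)

def count_where(arr):
    # np.where returns one index array per dimension; for 1-D input take the first.
    return len(np.where(arr == 0)[0])

def func_where():
    # Hypothetical reconstruction of the function timed as func_where.
    for arr in arrs:
        zero_els = count_where(arr)

# All three approaches should agree on every array.
for arr in arrs:
    assert count_len_minus_nonzero(arr) == count_mask(arr) == count_where(arr)
```

Absolute timings will differ by machine and NumPy version, so it is worth re-running %timeit locally rather than relying on the numbers quoted here.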
More speedups with accelerators
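It is now possible to achieve a speed boost of more than 3 orders of magnitude with the help of JAX if you have access to accelerators (GPU/TPU). Another advantage of using JAX is that NumPy code needs very little modification to become JAX compatible. Below is a reproducible example. Keep in mind that the jitted function takes no arguments and returns nothing, and that JAX dispatches work asynchronously, so the timing shown may mostly measure call overhead rather than the counting itself; for a fair comparison, return the counts and call block_until_ready() on the result.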
    In [1]: import numpy as np
    In [2]: import jax.numpy as jnp
    In [3]: from jax import jit

    # set up inputs
    In [4]: arrs = []
    In [5]: for _ in range(1000):
       ...:     arrs.append(np.random.randint(-5, 5, 10000))

    # JIT'd function that performs the counting task
    In [6]: @jit
       ...: def func_cnt():
       ...:     for arr in arrs:
       ...:         zero_els = jnp.count_nonzero(arr==0)

    # efficiency test
    In [8]: %timeit func_cnt()
    15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)