Efficiently count zero elements in numpy array?

I need to count the number of zero elements in numpy arrays. I'm aware of the numpy.count_nonzero function, but there appears to be no analog for counting zero elements.

My arrays are not very large (typically fewer than 1E5 elements), but the operation is performed several million times.

Of course I could use len(arr) - np.count_nonzero(arr), but I wonder if there's a more efficient way to do it.
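A quick check on that formula, as a minimal sketch: len() returns only the length of the first axis, so for multidimensional arrays arr.size (the total number of elements) is the safer term. The 2-D array here is illustrative and not from the question.

```python
import numpy as np

arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

# len() only sees the first axis (2 rows here), so this undercounts badly.
wrong = len(arr) - np.count_nonzero(arr)            # 2 - 6 = -4

# arr.size is the total element count (8), which works for any shape.
zeros_by_size = arr.size - np.count_nonzero(arr)    # 8 - 6 = 2

# Counting the True entries of the boolean mask gives the same answer.
zeros_by_mask = np.count_nonzero(arr == 0)          # 2

print(wrong, zeros_by_size, zeros_by_mask)
```

For 1-D arrays like the ones in the question's benchmark, len(arr) and arr.size are equal, so the original formula is correct there.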

Here's an MWE (minimal working example) of how I currently do it:

import numpy as np
import timeit

arrs = []
for _ in range(1000):
    arrs.append(np.random.randint(-5, 5, 10000))


def func1():
    for arr in arrs:
        zero_els = len(arr) - np.count_nonzero(arr)


print(timeit.timeit(func1, number=10))
Gabriel asked Mar 21 '17 at 00:03


People also ask

How do you count zero elements in a NumPy array?

To count all the zeros in an array, pass a boolean condition to np.count_nonzero(), as in np.count_nonzero(arr == 0). It returns the number of elements for which the condition holds (here, the elements equal to zero).

How do you count nonzero elements in a NumPy array?

The np.count_nonzero() function counts the number of non-zero values in the array arr. Parameters: arr (array_like), the array in which to count non-zeros; axis (int or tuple, optional), the axis or tuple of axes along which to count non-zeros.
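A small sketch of the axis parameter, assuming NumPy 1.12 or later (where axis was added to count_nonzero): combined with a boolean mask, it counts zeros per column or per row in a single call. The array is illustrative.

```python
import numpy as np

arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

# Total number of zeros in the whole array.
total = np.count_nonzero(arr == 0)                 # 2

# Zeros in each column (reduce over axis 0, i.e. down the rows).
per_column = np.count_nonzero(arr == 0, axis=0)    # [0 0 2 0]

# Zeros in each row (reduce over axis 1, i.e. across the columns).
per_row = np.count_nonzero(arr == 0, axis=1)       # [1 1]

print(total, per_column, per_row)
```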

How do you count the number of zeros in Python?

To count the number of zeros in a given list, use the list.count(0) method call.
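One caveat, sketched below: list.count() compares elements with ==, so 0.0 and False are counted as zeros too. The list is a made-up example, and the strict variant is just one possible way to count only int zeros.

```python
values = [0, 1, 0.0, 2, False, 3]

# list.count(0) uses equality, so 0.0 and False both match 0.
loose = values.count(0)                                         # 3

# Count only genuine int zeros; bool is a subclass of int, so check the exact type.
strict = sum(1 for v in values if type(v) is int and v == 0)    # 1

print(loose, strict)
```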


1 Answer

A faster approach (about 1.6x in the timings below) is to call np.count_nonzero() directly on the condition you need, here arr == 0.

In [3]: arr
Out[3]:
array([[1, 2, 0, 3],
       [3, 9, 0, 4]])

In [4]: np.count_nonzero(arr==0)
Out[4]: 2

In [5]: def func_cnt():
   ...:     for arr in arrs:
   ...:         # arr == 0 is a boolean mask, so this counts the zeros directly
   ...:         zero_els = np.count_nonzero(arr==0)
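A self-contained sketch that checks the two formulas agree before swapping one for the other. It builds data shaped like the question's (1000 arrays of 10000 ints in [-5, 5)) but uses a seeded np.random.default_rng, an assumption for reproducibility; the helper names zeros_by_subtraction and zeros_by_mask are hypothetical.

```python
import numpy as np

# Data shaped like the question's: 1000 arrays of 10000 ints in [-5, 5).
rng = np.random.default_rng(0)
arrs = [rng.integers(-5, 5, 10000) for _ in range(1000)]


def zeros_by_subtraction(arr):
    # The question's approach; arr.size equals len(arr) for these 1-D arrays.
    return arr.size - np.count_nonzero(arr)


def zeros_by_mask(arr):
    # The answer's approach: count the True entries of the mask arr == 0.
    return np.count_nonzero(arr == 0)


# Count arrays on which the two approaches disagree (expected: none).
mismatches = sum(zeros_by_subtraction(a) != zeros_by_mask(a) for a in arrs)
print("mismatches:", mismatches)
```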

You can also use np.where(), but it's slower than np.count_nonzero():

In [6]: np.where(arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))

In [7]: len(np.where(arr == 0)[0])
Out[7]: 2
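A sketch of why counting with np.where needs the [0]: with a single argument, np.where returns a tuple holding one index array per dimension, so len() of that tuple is the number of dimensions. In the 2-D example that happens to equal the zero count; on a 1-D array it does not. The arrays are illustrative.

```python
import numpy as np

arr2d = np.array([[1, 2, 0, 3],
                  [3, 9, 0, 4]])
arr1d = np.array([0, 0, 0, 5])

# np.where(cond) returns a tuple of index arrays, one per dimension.
print(len(np.where(arr2d == 0)))   # 2 (ndim, which happens to equal the zero count)
print(len(np.where(arr1d == 0)))   # 1 (ndim, although there are 3 zeros)

# Count the indices along any one axis instead.
zeros_2d = np.where(arr2d == 0)[0].size   # 2
zeros_1d = np.where(arr1d == 0)[0].size   # 3
print(zeros_2d, zeros_1d)
```

Even when used correctly, this path materializes index arrays that np.count_nonzero(arr == 0) never builds, which plausibly accounts for np.where being the slowest option.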

Efficiency (fastest first):

In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop

In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop

In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop
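The timing session does not show how func_where was defined, so here is a hedged, self-contained harness for reproducing the comparison. func_where is a hypothetical reconstruction, the seeded generator is an assumption, and each function returns its counts so the three can be checked against each other; absolute numbers will vary with hardware and NumPy version.

```python
import timeit

import numpy as np

# Data shaped like the question's: 1000 arrays of 10000 ints in [-5, 5).
rng = np.random.default_rng(0)
arrs = [rng.integers(-5, 5, 10000) for _ in range(1000)]


def func1():
    # Question's approach: length minus the non-zero count.
    return [len(arr) - np.count_nonzero(arr) for arr in arrs]


def func_cnt():
    # Answer's approach: count the True entries of the mask arr == 0.
    return [np.count_nonzero(arr == 0) for arr in arrs]


def func_where():
    # Hypothetical reconstruction; the original answer does not show this function.
    return [np.where(arr == 0)[0].size for arr in arrs]


# The timings are only comparable if all three produce the same counts.
assert func1() == func_cnt() == func_where()

for func in (func_cnt, func1, func_where):
    # Best of 3 repeats, each running the function 5 times.
    seconds = min(timeit.repeat(func, number=5, repeat=3)) / 5
    print(f"{func.__name__}: {seconds * 1e3:.1f} ms per loop")
```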

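More speedups with accelerators

If you have access to an accelerator (GPU/TPU), JAX may speed this up further, and the NumPy code needs very little modification to become JAX-compatible. Treat the timing below with caution, though: the jitted function returns nothing, so the counting work can be discarded during compilation, and JAX dispatches work asynchronously, so a fair benchmark should return the counts and call .block_until_ready() on the result. The apparent speedup of more than 3 orders of magnitude therefore likely reflects call overhead rather than actual counting. Below is the example: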

In [1]: import numpy as np
   ...: import jax.numpy as jnp

In [2]: from jax import jit

# set up inputs
In [3]: arrs = []

In [4]: for _ in range(1000):
   ...:     arrs.append(np.random.randint(-5, 5, 10000))

# JIT'd function that performs the counting task
In [5]: @jit
   ...: def func_cnt():
   ...:     for arr in arrs:
   ...:         # zero_els is never returned, so this work may be optimized away
   ...:         zero_els = jnp.count_nonzero(arr==0)

# efficiency test
In [8]: %timeit func_cnt()
15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
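A sketch of a fairer JAX benchmark, under stated assumptions: all arrays have the same length (true for the question's data), so they can be stacked into one 2-D array and counted with a single jitted call using axis=1, which batches the work unlike the per-array loop above; the hypothetical function count_zeros_per_row returns its result so the work cannot be dropped. Compilation is excluded from the timing, and block_until_ready() waits for the asynchronous computation. Numbers on a CPU-only machine will differ from those on a GPU/TPU.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp

# Data shaped like the question's, stacked into one (1000, 10000) array.
rng = np.random.default_rng(0)
host_arrs = rng.integers(-5, 5, size=(1000, 10000)).astype(np.int32)

# Move the data to the default device once, outside the timed region.
x = jax.device_put(host_arrs)


@jax.jit
def count_zeros_per_row(x):
    # Return the per-row counts so the computation has an output to produce.
    return jnp.count_nonzero(x == 0, axis=1)


# The first call triggers compilation; keep it out of the timing.
count_zeros_per_row(x).block_until_ready()

n_calls = 100
start = time.perf_counter()
for _ in range(n_calls):
    # block_until_ready() waits for the asynchronously dispatched work to finish.
    counts = count_zeros_per_row(x).block_until_ready()
elapsed = (time.perf_counter() - start) / n_calls
print(f"{elapsed * 1e6:.1f} µs per call")
```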
kmario23 answered Oct 02 '22 at 09:10