How to test if all rows are equal in a numpy array

In numpy, is there a nice idiomatic way of testing if all rows are equal in a 2d array?

I can do something like

np.all([np.array_equal(M[0], M[i]) for i in range(1, len(M))])

This mixes Python lists with NumPy arrays, which is ugly and presumably also slow.

Is there a nicer/neater way?

asked Oct 02 '14 15:10 by graffe



2 Answers

One way is to check that every row of the array arr is equal to its first row arr[0]:

(arr == arr[0]).all() 
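To see why this works: arr[0] is a 1D array that broadcasts against every row of arr, so the comparison produces a boolean array with the same shape as arr. A minimal sketch, assuming a non-empty 2D array (the helper name all_rows_equal is made up for illustration):

```python
import numpy as np

def all_rows_equal(arr):
    """Return True if every row of a non-empty 2D array equals its first row."""
    arr = np.asarray(arr)
    # arr[0] broadcasts across all rows; the comparison has the same shape as arr.
    return bool((arr == arr[0]).all())

same = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
diff = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]])

print(all_rows_equal(same))  # True
print(all_rows_equal(diff))  # False
```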

Using equality == is fine for integer values, but if arr contains floating point values you could use np.isclose instead to check for equality within a given tolerance:

np.isclose(arr, arr[0]).all()
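As a small illustration of why the tolerance matters, the sketch below builds two rows that are mathematically identical but differ by floating point rounding. It relies on the default tolerances of np.isclose (rtol=1e-05, atol=1e-08); whether those suit your data is an assumption you should check.

```python
import numpy as np

# Both rows are mathematically [0.3, 1.0], but 0.1 + 0.2 is not exactly 0.3.
arr = np.array([[0.3, 1.0],
                [0.1 + 0.2, 1.0]])

exact = (arr == arr[0]).all()          # strict equality
close = np.isclose(arr, arr[0]).all()  # equality within tolerance

print(exact)  # False
print(close)  # True
```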

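If your array contains NaN and you want to avoid the tricky NaN != NaN issue, you could combine this approach with np.isnan. Note that this treats a NaN at any position as a match, whether or not the first row also has a NaN there:

(np.isclose(arr, arr[0]) | np.isnan(arr)).all() 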
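If a NaN should only count as equal when the first row has a NaN in the same column, the equal_nan=True argument of np.isclose may be a better fit. The comparison below is only a sketch; the two helper names are hypothetical and exist just to put the approaches side by side.

```python
import numpy as np

nan = np.nan
# NaN sits in the same column of every row.
aligned = np.array([[1.0, nan], [1.0, nan]])
# NaN in the second row stands where the first row holds 2.0.
misaligned = np.array([[1.0, 2.0], [1.0, nan]])

def rows_equal_any_nan(arr):
    # Mask approach: every NaN entry is accepted, wherever it appears.
    return bool((np.isclose(arr, arr[0]) | np.isnan(arr)).all())

def rows_equal_matching_nan(arr):
    # equal_nan=True accepts a NaN only where the first row is NaN too.
    return bool(np.isclose(arr, arr[0], equal_nan=True).all())

print(rows_equal_any_nan(aligned), rows_equal_matching_nan(aligned))        # True True
print(rows_equal_any_nan(misaligned), rows_equal_matching_nan(misaligned))  # True False
```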
answered Sep 25 '22 06:09 by Alex Riley


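Simply check whether the number of unique items in the array is 1. Note that np.unique flattens the array by default, so this tests that all elements are identical, which is stricter than all rows being equal: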

>>> arr = np.array([[1]*10 for _ in range(5)])
>>> len(np.unique(arr)) == 1
True
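Since np.unique(arr) works on the flattened array, it answers "are all elements equal?" rather than "are all rows equal?". One possible variant for the row question is the axis=0 argument (available since NumPy 1.13), sketched below. It sorts the rows internally, so it is likely slower than a direct comparison against arr[0].

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [1, 2, 3],
                [1, 2, 3]])

# Flattened unique values: three distinct elements, although the rows match.
all_elements_same = len(np.unique(arr)) == 1
# Unique rows: a single distinct row means all rows are equal.
all_rows_same = len(np.unique(arr, axis=0)) == 1

print(all_elements_same)  # False
print(all_rows_same)      # True
```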

A solution inspired by unutbu's answer:

>>> arr = np.array([[1]*10 for _ in range(5)])
>>> np.all(np.all(arr == arr[0,:], axis = 1))
True
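The inner np.all(..., axis=1) yields one boolean per row, which can also be used to locate the rows that differ from the first one. A short sketch with made-up data:

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [1, 2, 3],
                [1, 9, 3],
                [1, 2, 3]])

# One boolean per row: True where the row matches the first row.
row_matches = np.all(arr == arr[0, :], axis=1)
# Indices of the rows that differ from the first row.
mismatched = np.flatnonzero(~row_matches)

print(row_matches.tolist())     # [True, True, False, True]
print(mismatched.tolist())      # [2]
print(bool(row_matches.all()))  # False
```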

One problem with your code is that you build the entire list before applying np.all() to it, so no short-circuiting happens: every row is compared even when an early row already differs. It would be better to use Python's built-in all() with a generator expression:

all(np.array_equal(M[0], M[i]) for i in range(1, len(M)))
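To make the short-circuiting visible, the sketch below wraps np.array_equal in a hypothetical counting helper (counted_equal is not part of NumPy) and compares how many rows each version inspects when the second row already differs from the first.

```python
import numpy as np

counter = {"calls": 0}

def counted_equal(a, b):
    """np.array_equal plus a counter, to show how many rows get compared."""
    counter["calls"] += 1
    return np.array_equal(a, b)

# The second row already differs from the first, so the answer is known early.
M = np.array([[3] * 100] + [[2] * 100 for _ in range(1000)])

counter["calls"] = 0
list_result = np.all([counted_equal(M[0], M[i]) for i in range(1, len(M))])
list_calls = counter["calls"]

counter["calls"] = 0
gen_result = all(counted_equal(M[0], M[i]) for i in range(1, len(M)))
gen_calls = counter["calls"]

print(list_result, list_calls)  # False 1000
print(gen_result, gen_calls)    # False 1
```

The list version performs all 1000 comparisons before np.all sees any of them, while the generator version stops at the first mismatch.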

Timing comparisons:

>>> M = arr = np.array([[3]*100] + [[2]*100 for _ in range(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 272 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 596 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in range(1,len(M))])
100 loops, best of 3: 10.6 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in range(1,len(M)))
100000 loops, best of 3: 11.3 µs per loop

>>> M = arr = np.array([[2]*100 for _ in range(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 330 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 594 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in range(1,len(M))])
100 loops, best of 3: 9.51 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in range(1,len(M)))
100 loops, best of 3: 9.44 ms per loop
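Outside IPython, a comparison along these lines could be run with the standard timeit module. This is only a rough sketch: the function names are made up, the repeat count of 20 is arbitrary, and the absolute numbers will depend on your machine and NumPy version.

```python
import timeit
import numpy as np

def broadcast_check(M):
    return bool(np.all(np.all(M == M[0, :], axis=1)))

def diff_check(M):
    return bool((np.diff(M, axis=0) == 0).all())

def generator_check(M):
    return all(np.array_equal(M[0], M[i]) for i in range(1, len(M)))

# First case: the mismatch is in the very first row. Second case: all rows equal.
early_mismatch = np.array([[3] * 100] + [[2] * 100 for _ in range(1000)])
all_equal = np.array([[2] * 100 for _ in range(1000)])

runs = 20
for label, M in [("early mismatch", early_mismatch), ("all equal", all_equal)]:
    for check in (broadcast_check, diff_check, generator_check):
        seconds = timeit.timeit(lambda: check(M), number=runs)
        print(f"{label:15s} {check.__name__:16s} {seconds / runs * 1e6:10.1f} us per call")
```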
answered Sep 26 '22 06:09 by Ashwini Chaudhary