 

Finding whether a list contains a particular numpy array

Tags:

python

numpy

import numpy as np

a = np.eye(2)
b = np.array([[1, 1], [0, 1]])

my_list = [a, b]

a in my_list returns True, but b in my_list raises "ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()". I can get around this by converting the arrays to strings or lists first, but is there a nicer (more Pythonic) way of doing it?

Chris Midgley asked Nov 02 '22 18:11


1 Answer

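The problem is that in NumPy the == operator compares element-wise and returns an array rather than a single True/False, so the list's in test cannot turn the comparison into a yes/no answer: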

>>> a == b
array([[ True, False],
       [ True,  True]], dtype=bool)
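Why does a in my_list work at all? The Python language reference defines x in y for a list as any(x is e or x == e for e in y), so the identity check comes first. a is the first element of my_list, so no array comparison ever happens; b is first compared with a, which yields a boolean array like the one above, and taking its truth value raises. The sketch below mimics that rule; list_contains is a hypothetical helper written for illustration, not a Python or NumPy API.

```python
import numpy as np

a = np.eye(2)
b = np.array([[1, 1], [0, 1]])
my_list = [a, b]

def list_contains(item, seq):
    """Hypothetical helper mimicking `item in seq` for a list:
    an identity check first, then ==, then bool() on the result."""
    for element in seq:
        if element is item or bool(element == item):
            return True
    return False

# a is my_list[0], so the identity check succeeds and == is never evaluated.
print(a in my_list)               # True
print(list_contains(a, my_list))  # True

# b is not my_list[0], so b is compared with a via ==, which produces a
# 2x2 boolean array; calling bool() on that array raises ValueError.
try:
    found = b in my_list
except ValueError as exc:
    print("ValueError:", exc)
```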

Use np.array_equal() instead: it checks that two arrays have the same shape and the same elements, and returns a single boolean.

>>> any(np.array_equal(a, x) for x in my_list)
True
>>> any(np.array_equal(b, x) for x in my_list)
True
>>> any(np.array_equal(np.array([a, a]), x) for x in my_list)
False
>>> any(np.array_equal(np.array([[0,0],[0,0]]), x) for x in my_list)
False
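The generator expression above can be wrapped in a reusable check. This is a sketch; contains_array and contains_array_close are hypothetical names, and the tolerance-based variant is an assumption for floating-point data, where exact equality is often too strict.

```python
import numpy as np

def contains_array(target, arrays):
    """Return True if any array in `arrays` has the same shape and values as `target`."""
    target = np.asarray(target)
    return any(np.array_equal(target, x) for x in arrays)

def contains_array_close(target, arrays, rtol=1e-05, atol=1e-08):
    """Tolerance-based variant for floating-point data.

    The explicit shape check matters: np.allclose broadcasts, so without it a
    (2, 2, 2) stack of copies of `a` would compare as "close" to `a` itself.
    """
    target = np.asarray(target)
    return any(
        np.shape(x) == target.shape and np.allclose(target, x, rtol=rtol, atol=atol)
        for x in arrays
    )

a = np.eye(2)
b = np.array([[1, 1], [0, 1]])
my_list = [a, b]

print(contains_array(b, my_list))                       # True
print(contains_array(np.array([a, a]), my_list))        # False: shapes differ
print(contains_array(a + 1e-12, my_list))               # False: not exactly equal
print(contains_array_close(a + 1e-12, my_list))         # True: within tolerance
print(contains_array_close(np.array([a, a]), my_list))  # False: shapes differ
```

np.array_equal already returns False when the shapes differ, so the exact version needs no separate shape check; only the np.allclose version does.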
kennytm answered Nov 15 '22 04:11