 

How to compare an array against a list of arrays?

Suppose I have a list that contains a bunch of numpy ndarrays (or even torch Tensors):

import numpy as np
a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = [a, b, c]

Now, if I wanted to check whether array b is in collection (assuming I don't know which arrays collection contains), evaluating b in collection raises the following error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

And the same would apply to tuples containing arrays.
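The error comes from how Python evaluates `in` on a list: each element is checked with `is` first and then with `==`. For NumPy arrays, `==` returns an elementwise boolean array, and turning that array into a single `bool` is what raises. A small demonstration, assuming only NumPy:

```python
import numpy as np

a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = [a, b, c]

# `a` is the very first element, so the identity check (`is`) succeeds
# before any `==` comparison is attempted.
print(a in collection)  # True

# For `b`, Python first evaluates bool(a == b); `a == b` is a 3x3 boolean
# array, and converting it to a single bool raises ValueError.
try:
    b in collection
except ValueError as err:
    print("ValueError:", err)
```

Tuples fail the same way, because comparing `(a, 1) == (b, 1)` ends up evaluating `bool(a == b)` as well.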

One way around this is a list comprehension:

True in [(b == x).all() for x in collection]

However, this requires a Python-level loop. Is there a more "efficient" way to accomplish this?
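If a loop is acceptable, it can at least be made cheaper: `any()` over a generator stops at the first match instead of building the whole list, and `np.array_equal` returns a plain `bool` (and `False` for mismatched shapes rather than trying to broadcast). This is still a Python-level loop, so it helps most when a match tends to appear early. `contains_array` is a hypothetical helper name; for torch tensors, `torch.equal` plays the role of `np.array_equal`.

```python
import numpy as np


def contains_array(collection, target):
    """Return True if any array in `collection` equals `target` exactly.

    Hypothetical helper, shown for illustration only.
    """
    # any() short-circuits on the first True; np.array_equal compares shape
    # and contents and returns a plain bool, so no ambiguity error occurs.
    return any(np.array_equal(x, target) for x in collection)


a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = [a, b, c]

print(contains_array(collection, b))               # True
print(contains_array(collection, b + 1.0))         # False: rand() values are < 1
print(contains_array(collection, np.zeros((2, 2))))  # False: shape differs
```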

asked Feb 15 '21 at 22:02 by Omar AlSuwaidi


2 Answers

I would use numpy arrays the whole way through:

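import numpy as np
a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
# stack along a new leading axis, so array[k] is the k-th matrix
array = np.stack([a, b, c])
print(array.shape)
# (3, 3, 3)
# b broadcasts against every array[k]; reduce over both matrix axes, then check for any match
print(np.all(array == b, axis=1).all(axis=1).any())
# True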
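`np.dstack` stacks 2-D arrays along a new *last* axis, so `depth[:, :, k]` is the k-th matrix, whereas `np.stack` puts the list index first. Plain broadcasting aligns `b` with the trailing axes, so the comparison and the reduction axes have to match the stacking call. A sketch of both consistent pairings, assuming equally shaped 2-D inputs:

```python
import numpy as np

a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)

# np.stack: shape (3, 3, 3), stacked[k] is the k-th matrix.
stacked = np.stack([a, b, c])
# b (3, 3) broadcasts against each stacked[k]; reduce over matrix axes 1 and 2.
found_stack = (stacked == b).all(axis=(1, 2)).any()

# np.dstack: shape (3, 3, 3), depth[:, :, k] is the k-th matrix.
depth = np.dstack([a, b, c])
# Give b a trailing length-1 axis so it lines up with depth[:, :, k],
# then reduce over matrix axes 0 and 1.
found_dstack = (depth == b[:, :, np.newaxis]).all(axis=(0, 1)).any()

print(found_stack, found_dstack)  # True True
```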
answered Oct 05 '22 at 15:10 by Paul H


You can stack equally shaped tensors of any dimensionality along axis=0 into a single numpy array and then compare all remaining axes at once with np.all (this is just a slightly clearer version of Paul H's answer):

import numpy as np

ugly_shaped_tensor_list = [np.random.rand(3, 7, 5, 3) for _ in range(10)]
known_tensor = ugly_shaped_tensor_list[1]

# stack all tensors in a single array along axis=0:
tensor_stack = np.stack(ugly_shaped_tensor_list)

# compare all axes except the "list" axis, 0:
matches = np.all(tensor_stack == known_tensor, axis=(1,2,3,4))
# array([False,  True, False, False, False, False, False, False, False, False])
matches.any()
# True
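The same idea can be generalized a little, assuming every array in the list has the target's shape (`np.stack` raises `ValueError` otherwise): derive the reduction axes from `ndim` instead of hard-coding `(1, 2, 3, 4)`, return the matching indices with `np.flatnonzero`, and optionally swap exact `==` for `np.isclose` when the arrays come from floating-point computations. `find_matches` and its `exact` flag are hypothetical; the tolerances are `np.isclose`'s own defaults.

```python
import numpy as np


def find_matches(arrays, target, exact=True):
    """Return the indices of the arrays in `arrays` that match `target`.

    Hypothetical helper: assumes all arrays share target's shape.
    """
    stack = np.stack(arrays)               # shape: (len(arrays), *target.shape)
    if exact:
        equal = stack == target            # elementwise, broadcast over axis 0
    else:
        equal = np.isclose(stack, target)  # tolerant float comparison
    # Reduce over every axis except the leading "list" axis.
    per_array = equal.all(axis=tuple(range(1, stack.ndim)))
    return np.flatnonzero(per_array)


tensors = [np.random.rand(3, 7, 5, 3) for _ in range(10)]
known = tensors[1]

print(find_matches(tensors, known))                       # [1]
print(find_matches(tensors, known + 1e-12, exact=False))  # [1]
print(find_matches(tensors, known + 1e-12).size)          # 0
```

Returning indices rather than a single `bool` costs nothing extra here, since the per-array mask is computed anyway, and `.size > 0` recovers the plain membership test.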
answered Oct 05 '22 at 15:10 by anon01