Suppose I have a list that contains a bunch of NumPy ndarrays (or even torch Tensors):
import numpy as np
a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = [a, b, c]
Now, if I want to check whether array b is in collection (assuming I don't know which arrays collection contains), then attempting
b in collection
raises the following error:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
The same happens with tuples containing arrays.
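The error comes from how in works on lists and tuples: the Python language reference describes x in y as equivalent to any(x is e or x == e for e in y). For NumPy arrays == is elementwise, so every element that is not the very same object produces a boolean array, and converting that array to a single True/False is ambiguous. A minimal demonstration with two random arrays:

```python
import numpy as np

a, b = np.random.rand(3, 3), np.random.rand(3, 3)

# b is the first element, so the identity test (b is b) succeeds
# and == is never evaluated.
first_hit = b in [b, a]
print(first_hit)  # True

# Here a is checked first: a is not b, so Python compares a and b with ==,
# which gives a 3x3 boolean array that cannot become a single True/False.
try:
    b in [a, b]
    error_name = None
except ValueError as err:
    error_name = type(err).__name__
print(error_name)  # ValueError
```

So whether in "works" depends on the position of the match, which is why it cannot be relied on for arrays.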
One way around this is a list comprehension:
True in [(b == x).all() for x in collection]
However, this requires a for loop. Is there a more "efficient" way to accomplish this?
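If the arrays have to stay in a list, some loop is unavoidable, but it can at least stop at the first match instead of building the whole list of results. A minimal sketch (contains_array is a hypothetical helper name) that uses np.array_equal, which returns False rather than broadcasting or raising when the shapes differ, inside any() with a generator expression:

```python
import numpy as np

def contains_array(collection, target):
    """Return True if some element of collection equals target elementwise."""
    # The generator lets any() stop at the first match; np.array_equal
    # returns a plain bool and is False for mismatched shapes.
    return any(np.array_equal(target, x) for x in collection)

a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = [a, b, c]
print(contains_array(collection, b))                 # True
print(contains_array(collection, np.zeros((2, 2))))  # False
```

This is still a Python-level loop; the stacking approaches trade it for copying all arrays into one block of memory.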
I would use numpy arrays the whole way through:
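import numpy as np
a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
# dstack puts the list index on the last axis: array[:, :, k] is the k-th input
array = np.dstack([a, b, c])
print(array.shape)
# (3, 3, 3)
# give b a trailing length-1 axis so it broadcasts against every slice k,
# then require every element of a slice to match
np.all(array == b[:, :, np.newaxis], axis=(0, 1)).any()
# True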
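With np.dstack the list index ends up on the last axis, so comparing array == b directly would broadcast b against the last two axes rather than against each stored array. A short sketch, assuming the same 3x3 arrays, that aligns b with every slice via a trailing new axis and also reports which slice matched using np.flatnonzero:

```python
import numpy as np

a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
array = np.dstack([a, b, c])  # shape (3, 3, 3); array[:, :, k] is the k-th input

# b[:, :, np.newaxis] has shape (3, 3, 1) and broadcasts against every slice k;
# reducing over axes 0 and 1 leaves one True/False per slice.
per_slice = (array == b[:, :, np.newaxis]).all(axis=(0, 1))
matching = np.flatnonzero(per_slice)  # indices of slices equal to b
print(per_slice)  # [False  True False]
print(matching)   # [1]
```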
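You can stack tensors of any shape (as long as they all share it) along axis=0 in a NumPy array and then compare all remaining axes at once with np.all. Stacking along axis=0 matters because broadcasting aligns trailing axes, so the known tensor lines up with each stacked entry (this is just a slightly clearer version of PaulH's answer):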
import numpy as np
ugly_shaped_tensor_list = [np.random.rand(3, 7, 5, 3) for _ in range(10)]
known_tensor = ugly_shaped_tensor_list[1]
# stack all tensors in a single array along axis=0 -> shape (10, 3, 7, 5, 3):
tensor_stack = np.stack(ugly_shaped_tensor_list)
# compare all axes except the "list" axis, 0:
matches = np.all(tensor_stack == known_tensor, axis=(1, 2, 3, 4))
# array([False,  True, False, False, False, False, False, False, False, False])
matches.any()
# True
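The axis tuple (1, 2, 3, 4) is tied to 4-D tensors. A sketch of a variant that derives it from ndim and can optionally use np.isclose for float data that was recomputed rather than copied; the helper name stack_contains and the atol default are assumptions for illustration. Note that np.stack requires all arrays to have the same shape and copies them into a new array.

```python
import numpy as np

def stack_contains(arrays, target, approx=False, atol=1e-8):
    """Return a boolean array saying which entries of arrays match target."""
    stacked = np.stack(arrays)  # new axis 0 indexes the list entries
    # every axis except the list axis must match
    other_axes = tuple(range(1, stacked.ndim))
    if approx:
        # absolute tolerance only (rtol=0.0), an assumed default for this sketch
        equal = np.isclose(stacked, target, rtol=0.0, atol=atol)
    else:
        equal = stacked == target
    return np.all(equal, axis=other_axes)

tensors = [np.random.rand(3, 7, 5, 3) for _ in range(10)]
query = tensors[1] + 1e-12  # shifts every element, so no exact match remains

print(stack_contains(tensors, tensors[1]).any())         # True
print(stack_contains(tensors, query).any())              # False
print(stack_contains(tensors, query, approx=True).any()) # True
```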