Select all rows from Numpy array where each column satisfies some condition

I have an array x of the form,

x = [[1,2,3,...,7,8,9],
[1,2,3,...,7,9,8],
...,
[9,8,7,...,3,1,2],
[9,8,7,...,3,2,1]]

I also have an array of non-allowed numbers for each column. I want to select all of the rows which have only allowed numbers in each column. For instance, I might want only the rows which do not have any of [1,2,3] in the first column; I can do this with,

x[~np.in1d(x[:,0], [1,2,3])]

I can do this for any single column. But I'm looking to do this for all columns at once, selecting only the rows for which every element is an allowed number for its column. I can't seem to get x.any or x.all to do this well. How should I go about this?

EDIT: To clarify, the non-allowed numbers are different for each column. In actuality, I will have some array y,

y = [[1,4,...,7,8],
[2,5,...,9,4],
[3,6,...,8,6]]

Where I want rows from x for which column 1 cannot be in [1,2,3], column 2 cannot be in [4,5,6], and so on.
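
For reference, the requirement can be spelled out as a plain per-column loop. This is only a sketch of the intended behaviour, not code from the original post: the helper name rows_allowed_loop and the small arrays are made up for illustration, and it assumes column j of y lists the values banned in column j of x.

```python
import numpy as np

def rows_allowed_loop(x, y):
    """Keep rows of x where x[i, j] is not among y[:, j] for every column j."""
    keep = np.ones(x.shape[0], dtype=bool)
    for j in range(x.shape[1]):
        # Drop rows whose j-th entry is one of the banned values for column j
        keep &= ~np.isin(x[:, j], y[:, j])
    return x[keep]

# Made-up data: column j of y holds the values banned in column j of x
x = np.array([[1, 4, 7],
              [4, 1, 2],
              [2, 9, 9]])
y = np.array([[1, 4, 7],
              [2, 5, 8],
              [3, 6, 9]])

print(rows_allowed_loop(x, y))  # [[4 1 2]]
```

Only the middle row survives: the first and last rows each have a banned value in column 1.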

Vedvart1 asked Oct 28 '25 07:10


1 Answer

You can broadcast the comparison against y, then reduce with all to check every row:

x[(x != y[:,None,:]).all(axis=(0,-1))]

Breakdown:

# compare each element of `x` to each element of `y`
# mask.shape == (y.shape[0], x.shape[0], x.shape[1])
mask = (x != y[:,None,:])

# `all(0)` checks that each element of `x` matches none of the elements in the same column of `y`
# `all(-1)` then checks that this holds along each row of `x`
mask = mask.all(axis=(0,-1))

# slice
x[mask]
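
To make the shapes in the breakdown concrete, here is a small sketch that prints them at each stage. The arrays are made up for illustration (4 rows, 3 columns, 2 banned values per column) and the names y3, cmp and row_ok are mine, not from the answer above.

```python
import numpy as np

x = np.arange(12).reshape(4, 3)      # rows: [0 1 2], [3 4 5], [6 7 8], [9 10 11]
y = np.array([[0, 7, 20],
              [30, 40, 2]])          # column j lists values banned in column j of x

y3 = y[:, None, :]                   # shape (2, 1, 3): new axis for the rows of x
cmp = x != y3                        # broadcasts to shape (2, 4, 3)
row_ok = cmp.all(axis=(0, -1))       # reduce over banned values and columns -> (4,)

print(y3.shape, cmp.shape, row_ok.shape)  # (2, 1, 3) (2, 4, 3) (4,)
print(row_ok)                             # [False  True False  True]
```

Note that the intermediate boolean array holds y.shape[0] * x.size entries, so for a very large x or a long list of banned values a per-column loop may use less memory.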

For example, consider:

x = np.array([[1, 2],
              [9, 8],
              [5, 6],
              [7, 8]])

y = np.array([[1, 4],
              [2, 5],
              [3, 7]])

Then mask = (x != y[:,None,:]).all(axis=(0,-1)) gives

array([False,  True,  True,  True])
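
As a runnable sketch of this example end to end (the variable name selected is mine; everything else follows the arrays above), reducing over the first and last axes and then indexing x with the result:

```python
import numpy as np

x = np.array([[1, 2],
              [9, 8],
              [5, 6],
              [7, 8]])
y = np.array([[1, 4],
              [2, 5],
              [3, 7]])

# One boolean per row of x: True when no column holds a banned value
mask = (x != y[:, None, :]).all(axis=(0, -1))
selected = x[mask]

print(mask)      # [False  True  True  True]
print(selected)  # rows [9 8], [5 6], [7 8]
```

The first row is dropped because its first entry, 1, is in the banned list [1, 2, 3] for the first column; the other rows contain no banned value in either column.
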
Quang Hoang answered Oct 31 '25 12:10