 

Filter by index and flatten in NumPy, like tf.sequence_mask

I would like to filter my 2D array using a per-row length, and then flatten it so that only the values inside the filter remain. This is pretty much what tf.sequence_mask does, but I need it in NumPy or another lightweight library.

Thanks!

P.S. Here is an example:

import numpy as np

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]])
array_len = [6,5,3]  # number of leading elements to keep in each row
expected_output = [0,1,2,3,4,5,8,9,10,11,12,21,22,21]
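To pin down the intended semantics, here is a plain loop-based sketch that produces expected_output by slicing each row to its length and concatenating the pieces. It uses only standard NumPy calls, but it iterates over rows in Python, so it will be slower than a vectorized mask when there are many rows.

```python
import numpy as np

array_2d = np.array([[0, 1, 2, 3, 4, 5],
                     [8, 9, 10, 11, 12, 0],
                     [21, 22, 21, 0, 0, 0]])
array_len = [6, 5, 3]

# Keep the first n elements of each row, then join the pieces into one 1-D array
flat = np.concatenate([row[:n] for row, n in zip(array_2d, array_len)])
print(flat)  # [ 0  1  2  3  4  5  8  9 10 11 12 21 22 21]
```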
Isaac asked Jan 31 '26 18:01



1 Answer

Here's a vectorized solution, using a boolean mask to index array_2d:

import numpy as np

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]])
array_len = [6,5,3]

# m is True where the 1-based column position is <= that row's length
m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array_2d[m]
# array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])
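An alternative that avoids building a float array and the double transpose is to compare column indices against the lengths directly with broadcasting. The helper name sequence_mask and its maxlen parameter are my own, chosen to mirror the semantics of tf.sequence_mask (mask[i, j] is True when j < lengths[i]); this is a sketch, not a NumPy API.

```python
import numpy as np

def sequence_mask(lengths, maxlen=None):
    """Hypothetical helper mirroring tf.sequence_mask: mask[i, j] = j < lengths[i]."""
    lengths = np.asarray(lengths)
    if maxlen is None:
        maxlen = int(lengths.max())
    # A (maxlen,) row of column indices vs an (n, 1) column of lengths
    # broadcasts to an (n, maxlen) boolean mask
    return np.arange(maxlen) < lengths[:, None]

array_2d = np.array([[0, 1, 2, 3, 4, 5],
                     [8, 9, 10, 11, 12, 0],
                     [21, 22, 21, 0, 0, 0]])
array_len = [6, 5, 3]

m = sequence_mask(array_len, maxlen=array_2d.shape[1])
print(array_2d[m])  # [ 0  1  2  3  4  5  8  9 10 11 12 21 22 21]
```

Passing maxlen explicitly keeps the mask the same width as array_2d even when no row uses the full width.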

Details

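The mask is created by taking the cumulative sum along each row of an ndarray of ones with the same shape as array_2d, which gives every element its 1-based column position, and then performing a row-wise comparison to see which positions are greater than that row's entry in array_len (the transposes line array_len up with the rows for broadcasting). Negating the result with ~ leaves True only for the elements to keep.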
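The transposes are needed because NumPy broadcasting aligns shapes from the trailing axis, so a (3, 6) array cannot be compared directly with a length-3 vector. This sketch checks the shapes involved and shows an equivalent form that adds an axis to array_len instead of transposing:

```python
import numpy as np

counts = np.ones((3, 6)).cumsum(axis=1)  # shape (3, 6): 1-based column positions
array_len = np.array([6, 5, 3])          # shape (3,)

# Trailing axes are 6 and 3, which do not match, so this raises ValueError
try:
    counts > array_len
    direct_ok = True
except ValueError as exc:
    direct_ok = False
    print("direct comparison failed:", exc)

# Transposing gives (6, 3), whose trailing axis matches array_len
print((counts.T > array_len).shape)  # (6, 3)

# Equivalent without transposes: turn array_len into a (3, 1) column vector
same = np.array_equal((counts.T > array_len).T, counts > array_len[:, None])
print(same)  # True
```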

So the first step is to create the following ndarray:

np.ones(array_2d.shape).cumsum(axis=1)

array([[1., 2., 3., 4., 5., 6.],
       [1., 2., 3., 4., 5., 6.],
       [1., 2., 3., 4., 5., 6.]])
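This matrix is just the 1-based column position repeated on every row, so it can also be built with np.arange and np.broadcast_to without any summation. A small sketch comparing the two (note that the cumsum version is float64, while the arange version is an integer, read-only view):

```python
import numpy as np

shape = (3, 6)
via_cumsum = np.ones(shape).cumsum(axis=1)                       # float64, 1..6 per row
via_arange = np.broadcast_to(np.arange(1, shape[1] + 1), shape)  # int, read-only view

# Values compare equal even though the dtypes differ
print(np.array_equal(via_cumsum, via_arange))  # True
```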

Then compare each row with array_len and negate the result, so that True marks the elements to keep:

~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T

array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True, False],
       [ True,  True,  True, False, False, False]])
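A quick sanity check on the mask: each row should contain exactly array_len[i] True values, and they should all sit at the start of the row. A minimal sketch of both checks:

```python
import numpy as np

array_2d = np.array([[0, 1, 2, 3, 4, 5],
                     [8, 9, 10, 11, 12, 0],
                     [21, 22, 21, 0, 0, 0]])
array_len = [6, 5, 3]

m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T

# Number of True entries per row should equal the requested lengths
print(m.sum(axis=1))  # [6 5 3]

# Leading-only: each row is non-increasing, so no True appears after a False
print(np.all(m[:, :-1] >= m[:, 1:]))  # True
```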

Then you simply have to filter the array with:

array_2d[m]
array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])
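Indexing with a boolean mask of the same shape as the array always returns a 1-D array, with elements taken in row-major (C) order, which is why the kept values come out row by row. The sketch below checks this against flattening both arrays first:

```python
import numpy as np

array_2d = np.array([[0, 1, 2, 3, 4, 5],
                     [8, 9, 10, 11, 12, 0],
                     [21, 22, 21, 0, 0, 0]])
array_len = [6, 5, 3]
m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T

flat = array_2d[m]
print(flat.ndim)  # 1

# Same result as flattening both arrays in row-major (C) order, then masking
print(np.array_equal(flat, array_2d.ravel()[m.ravel()]))  # True
```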
yatu answered Feb 02 '26 11:02
