I have an array of the shape (2,10) such as:
arr = jnp.ones(shape=(2,10)) * 2
which prints as:
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
and another array, for example [2, 4].
I want the second array to specify, for each row, the index from which the elements of arr should be masked (set to -1). Here the result would be:
[[2. 2. -1. -1. -1. -1. -1. -1. -1. -1.]
 [2. 2. 2. 2. -1. -1. -1. -1. -1. -1.]]
I need to use jax.numpy, and the solution should be vectorized and fast if possible, i.e. without Python loops.
You can do this with a vmapped three-argument jnp.where call. For example:
import jax.numpy as jnp
import jax
arr = jnp.ones(shape=(2,10)) * 2
idx = jnp.array([2, 4])
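# vmap maps f over the leading axis: each call sees one row and its index
@jax.vmap
def f(row, ind):
    # Keep positions before ind; replace everything from ind onward with -1
    return jnp.where(jnp.arange(len(row)) < ind, row, -1)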
f(arr, idx)
# DeviceArray([[ 2., 2., -1., -1., -1., -1., -1., -1., -1., -1.],
#              [ 2., 2., 2., 2., -1., -1., -1., -1., -1., -1.]], dtype=float32)
# (newer JAX versions print Array(...) instead of DeviceArray(...))
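An equivalent formulation, offered as an alternative sketch rather than as part of the answer: skip vmap and let broadcasting build the whole mask at once, by comparing a row of column positions (shape (n_cols,)) against idx reshaped to a column (shape (n_rows, 1)). The function name mask_from_index and its fill_value parameter are hypothetical, chosen for illustration; jax.jit is added on the assumption that the masking will be applied repeatedly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mask_from_index(arr, idx, fill_value=-1):
    """Hypothetical helper: set arr[i, j] to fill_value wherever j >= idx[i]."""
    cols = jnp.arange(arr.shape[1])        # column positions 0 .. n_cols-1 (static under jit)
    keep = cols[None, :] < idx[:, None]    # (n_rows, n_cols) boolean mask via broadcasting
    return jnp.where(keep, arr, fill_value)

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])
out = mask_from_index(arr, idx)
print(out)
```

Because the test is cols < idx[i], an index of 0 (or any negative value) masks the whole row, and an index of 10 or more leaves the row untouched. Whether this is measurably faster than the vmapped version depends on the backend and array sizes; under jit both reduce to an elementwise comparison followed by a select.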