I'm new to JAX and writing code that JIT compiles is proving to be quite hard for me. I am trying to achieve the following:
Given an (n,n) array mat in JAX, I would like to add a (1,n) or an (n,1) array to an arbitrary row or column, respectively, of the original array mat.
If I wanted to add a row array, r, to the third row, the NumPy equivalent would be:
# if mat is a numpy array
mat[2,:] = mat[2,:] + r
The only way I know how to update an element of an array in JAX is using array.at[i].set(). I am not sure how one can use this to update a row or a column without explicitly using a for-loop.
JAX arrays are immutable, so you cannot modify array entries in place. But you can get the same result with the jnp.ndarray.at syntax. For example, the equivalent of
mat[2,:] = mat[2,:] + r
would be
mat = mat.at[2,:].set(mat[2,:] + r)
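A small sketch of what "immutable" means in practice. It uses a 1D r of shape (3,), chosen here for illustration. Plain item assignment raises a TypeError, and .at[...].set(...) returns a new array while leaving the original untouched:

```python
import jax.numpy as jnp

mat = jnp.zeros((3, 3))
r = jnp.arange(3.0)  # 1D row of shape (3,), assumed for illustration

# NumPy-style in-place assignment is rejected for JAX arrays.
try:
    mat[2, :] = r
except TypeError as e:
    print("in-place assignment fails:", type(e).__name__)

# .at[...].set(...) builds and returns a NEW array; `mat` is unchanged.
new_mat = mat.at[2, :].set(mat[2, :] + r)

print(mat[2])      # [0. 0. 0.]
print(new_mat[2])  # [0. 1. 2.]
```

Because the result is a new value, you always rebind it (mat = mat.at[...]...) rather than expecting mat to change.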
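But in this case the add method is more efficient:
mat = mat.at[2, :].add(r)
(For a 2D (1, n) row array like the one in the question, the length-1 slice 2:3 keeps the shapes aligned, as in the example below.)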
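A quick check, assuming a 1D r for simplicity, that the set-based and add-based updates agree. It also shows why the row index should be written as 2, : (or just 2) rather than 2:, which selects every row from index 2 onward:

```python
import jax.numpy as jnp

mat = jnp.arange(9.0).reshape(3, 3)
r = jnp.ones(3)  # 1D row, assumed for illustration

via_set = mat.at[2, :].set(mat[2, :] + r)
via_add = mat.at[2, :].add(r)
print(jnp.array_equal(via_set, via_add))  # True

# The slice 2: targets rows 2, 3, ... rather than a single row.
mat4 = jnp.zeros((4, 4))
print(mat4.at[2:].add(1.0).sum(axis=1))  # [0. 0. 4. 4.]
```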
Here is an example of adding a row array and a column array to a 2D array:
import jax.numpy as jnp
mat = jnp.zeros((5, 5))
# Create 2D row & col arrays, as in question
row = jnp.ones(5).reshape(1, 5)
col = jnp.ones(5).reshape(5, 1)
# 1:2 and 2:3 are length-1 slices, so the targets have shapes (1, 5) and (5, 1)
mat = mat.at[1:2, :].add(row)
mat = mat.at[:, 2:3].add(col)
print(mat)
# [[0. 0. 1. 0. 0.]
#  [1. 1. 2. 1. 1.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]]
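Since the question is about code that JIT compiles, here is a sketch of the same update inside jax.jit with the row/column index passed as a runtime argument. The function names add_to_row and add_to_col are hypothetical. One caveat: under jit a traced index cannot be used as a slice bound (i:i+1), so this version uses an integer index and flattens the (1, n) / (n, 1) arrays to 1D before adding:

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_to_row(mat, r, i):
    # i is traced, so one compiled function serves every row index.
    # Flatten (1, n) -> (n,) to match the shape of mat[i, :].
    return mat.at[i, :].add(r.reshape(-1))

@jax.jit
def add_to_col(mat, c, j):
    # Flatten (n, 1) -> (n,) to match the shape of mat[:, j].
    return mat.at[:, j].add(c.reshape(-1))

mat = jnp.zeros((5, 5))
row = jnp.ones(5).reshape(1, 5)
col = jnp.ones(5).reshape(5, 1)

out = add_to_row(mat, row, 1)
out = add_to_col(out, col, 2)
print(out)  # same result as the example above
```

No Python for-loop is involved: each call lowers to a single scatter-add, and changing i or j does not trigger recompilation.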
See JAX Sharp Bits: In-Place Updates for more discussion of this.