
Updating entire row or column of a 2D array in JAX

I'm new to JAX and writing code that JIT compiles is proving to be quite hard for me. I am trying to achieve the following:

Given an (n,n) array mat in JAX, I would like to add a (1,n) or an (n,1) array to an arbitrary row or column, respectively, of the original array mat.

If I wanted to add a row array, r, to the third row, the NumPy equivalent would be:

# if mat is a numpy array
mat[2,:] = mat[2,:] + r

The only way I know to update an element of an array in JAX is array.at[i].set(). I am not sure how to use this to update a whole row or column without an explicit for-loop.

asked Oct 23 '25 14:10 by CartesianBear


1 Answer

JAX arrays are immutable, so you cannot modify array entries in place. But you can accomplish similar results with the jax.numpy.ndarray.at syntax. For example, the equivalent of

mat[2,:] = mat[2,:] + r

would be

mat = mat.at[2,:].set(mat[2,:] + r)
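To make the immutability point concrete, here is a minimal sketch (the 4x4 size and the 1D row r of shape (4,) are arbitrary choices for illustration): .at[...].set() returns an updated copy and leaves the original array untouched.

```python
import jax.numpy as jnp

# Arbitrary 4x4 example matrix and a 1D row of shape (4,)
mat = jnp.zeros((4, 4))
r = jnp.arange(4.0)

# .at[...].set() does not modify `mat`; it returns a new, updated array
new_mat = mat.at[2, :].set(mat[2, :] + r)

print(mat[2])      # original row is unchanged: [0. 0. 0. 0.]
print(new_mat[2])  # updated row: [0. 1. 2. 3.]
```

Because of this, the result always has to be reassigned (mat = mat.at[...]...) if you want to keep working with the updated values.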

But you can use the add method to be more efficient in this case:

mat = mat.at[2,:].add(r)
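As a quick sanity check, this sketch compares the set-based and add-based row updates and shows the column analogue, which just moves the index to the second axis. It assumes 1D update arrays of shape (n,); the matrix contents are arbitrary test data.

```python
import jax.numpy as jnp

n = 5
mat = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)  # arbitrary test data
r = jnp.ones(n)                           # 1D row update, shape (n,)
c = jnp.arange(n, dtype=jnp.float32)      # 1D column update, shape (n,)

# Row update: explicit read-modify-set vs. a single .add()
via_set = mat.at[2, :].set(mat[2, :] + r)
via_add = mat.at[2, :].add(r)

# Column update: index with [:, j] instead of [i, :]
col_added = mat.at[:, 3].add(c)

print(jnp.array_equal(via_set, via_add))               # True
print(jnp.array_equal(col_added[:, 3], mat[:, 3] + c))  # True
```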

Here is an example of adding a row array and a column array to a 2D array:

import jax.numpy as jnp

mat = jnp.zeros((5, 5))

# Create 2D row & col arrays, as in question
row = jnp.ones(5).reshape(1, 5)
col = jnp.ones(5).reshape(5, 1)

mat = mat.at[1:2, :].add(row)
mat = mat.at[:, 2:3].add(col)

print(mat)
# [[0. 0. 1. 0. 0.]
#  [1. 1. 2. 1. 1.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]]
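Since the question is about JIT-compiled code, here is a sketch of the same updates inside jax.jit with the row/column index passed as a traced argument, so "arbitrary" indices work without recompiling. The helper names add_to_row and add_to_col are hypothetical, not part of JAX. One caveat: slice bounds such as i:i+1 must be static under jit, so this sketch flattens the (1,n) and (n,1) inputs and uses a scalar index instead.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_to_row(mat, row, i):
    # row has shape (1, n); flatten to (n,) so it matches mat[i, :]
    return mat.at[i, :].add(row.reshape(-1))

@jax.jit
def add_to_col(mat, col, j):
    # col has shape (n, 1); flatten to (n,) so it matches mat[:, j]
    return mat.at[:, j].add(col.reshape(-1))

mat = jnp.zeros((5, 5))
row = jnp.ones((1, 5))
col = jnp.ones((5, 1))

# i and j are traced values, so calling with a different index reuses the compiled function
mat = add_to_row(mat, row, 1)
mat = add_to_col(mat, col, 2)
print(mat)  # same result as the non-jitted example above
```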

See JAX Sharp Bits: In-Place Updates for more discussion of this.

answered Oct 26 '25 04:10 by jakevdp


