This is a minimal example of a larger problem I am facing. Consider the function below:
```python
import jax
import jax.numpy as jnp

def test(x):
    return jnp.sum(x)
```
I tried to vectorize it with:
```python
v_test = jax.vmap(test)
```
My inputs to `test` look like:
```python
x1 = jnp.array([1, 2, 3])
x2 = jnp.array([4, 5, 6, 7])
x3 = jnp.array([8, 9])
x4 = jnp.array([10])
```
and my input to `v_test` is:
```python
x = [x1, x2, x3, x4]
```
If I try:
```python
v_test(x)
```
I get the error below:
```
ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([3, 4, 2, 1],)
```
Is there a way to vectorize `test` over a list of unequal-length arrays?
I could avoid this by padding the arrays to the same length; however, padding is not desired.
JAX does not support ragged arrays (i.e., arrays in which each row has a different number of elements), so there is currently no way to use `vmap` on this kind of data. Your best bet is probably a Python `for` loop:
```python
y = [test(xi) for xi in x]
```
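Alternatively, you might be able to express the operation you have in mind in terms of `jax.ops.segment_sum` or similar segment operations. For example:
```python
# One integer segment id per element: 0 for x1, 1 for x2, and so on.
segments = jnp.concatenate([jnp.full(xi.shape[0], i) for i, xi in enumerate(x)])
# Sum elements that share a segment id; num_segments fixes the output length.
result = jax.ops.segment_sum(jnp.concatenate(x), segments, num_segments=len(x))
print(result)
# [ 6 22 17 10]
```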
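The segment approach can also be wrapped in `jax.jit`. The sketch below is one way to do it, assuming the data arrives as a single concatenated array plus a vector of row lengths; `ragged_sum` is a hypothetical helper name, not a JAX API. It builds the segment ids with `jnp.repeat`, whose `total_repeat_length` argument keeps the output shape static under `jit`:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=2)
def ragged_sum(flat, lengths, num_segments):
    """Hypothetical helper: sum each ragged row of `flat`, with row sizes in `lengths`."""
    # Segment ids like [0, 0, 0, 1, 1, 1, 1, 2, 2, 3] built from the row lengths.
    # total_repeat_length must be static so the shape is known at trace time.
    segment_ids = jnp.repeat(
        jnp.arange(num_segments), lengths, total_repeat_length=flat.shape[0]
    )
    return jax.ops.segment_sum(flat, segment_ids, num_segments=num_segments)


x = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6, 7]), jnp.array([8, 9]), jnp.array([10])]
flat = jnp.concatenate(x)                       # all elements in one 1D array
lengths = jnp.array([xi.shape[0] for xi in x])  # row lengths: [3, 4, 2, 1]
result = ragged_sum(flat, lengths, len(x))
```

Because the compiled function depends only on the total element count and the number of rows, a different ragged layout with the same totals should reuse the cached compilation instead of retracing.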
Another possibility is to pad the input arrays so that they can fit into a standard, non-ragged 2D array.