Let's suppose I have a function that returns the sum of its inputs:
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2
Now I would like to loop over different values of r1 and r2, save each result, and add it to a running total. This is what I mean:
a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)
s = 0
for i in range(len(r1)):
    for j in range(len(r2)):
        s += some_func(a, r1[i], r2[j])
print(s)
# DeviceArray(18, dtype=int32)
My question is, how do I do this with jax.vmap to avoid writing the for loops? I have something like this so far:
vmap(some_func, in_axes=(None, 0,0), out_axes=0)(jnp.arange(0,3), jnp.arange(0,3))
but this gives me the following error:
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0, 0) for value tree PyTreeDef((*, *)).
I suspect the problem is in in_axes, but I am not sure how to get vmap to pick a value of r1, loop over r2, and then repeat this for every r1 while saving the intermediate results.
Any help is appreciated.
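vmap maps over a single axis at a time. (The error above arises because in_axes lists three entries while only two arguments are passed, so a is missing. Even with a supplied, in_axes=(None, 0, 0) would map over r1 and r2 together, pairing r1[i] with r2[i] rather than forming every combination.) Because you want every combination of r1 and r2, you'll need two nested vmap calls: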
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
func_mapped(a, r1, r2).sum()
# 18
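A minimal sketch contrasting the single vmap (with a passed in) against the nested version, and showing that the nested call keeps every intermediate value. The names `paired`, `grid` and `grid_ij` are illustrative only; the printed values assume a recent JAX release.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# A single vmap with in_axes=(None, 0, 0) maps r1 and r2 together:
# it pairs r1[i] with r2[i], so only the "diagonal" terms are computed.
paired = vmap(some_func, in_axes=(None, 0, 0))(a, r1, r2)
print(paired)        # [0 2 4]
print(paired.sum())  # 6, not 18

# Nested vmap: the inner call maps over r1 (r2 held fixed), the outer call
# maps over r2. The result has shape (len(r2), len(r1)), so
# grid[j, i] == some_func(a, r1[i], r2[j]).
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
grid = func_mapped(a, r1, r2)
print(grid.shape)  # (3, 3)
print(grid.sum())  # 18

# Swapping which argument the outer vmap maps over gives grid_ij[i, j],
# matching the i/j order of the original double loop.
grid_ij = vmap(vmap(some_func, (None, None, 0)), (None, 0, None))(a, r1, r2)
```

The intermediate results the question asks about are simply the entries of `grid` (or `grid_ij`); summing is a separate reduction on top.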
Alternatively, for a simple function like this you can avoid vmap and use numpy-style broadcasting to get the same result:
some_func(a, r1[None, :, None], r2[None, None, :]).sum()
# 18
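The sketch below checks the shapes involved in the broadcasting version. The leading `None` axis is harmless but not required; `r1[:, None]` and `r2[None, :]` are a simplification of my own (not part of the original answer) that yields a (3, 3) grid directly, indexed the same way as the double loop.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Broadcasting as written above: shapes (1, 3, 1) and (1, 1, 3) -> (1, 3, 3).
b3 = some_func(a, r1[None, :, None], r2[None, None, :])
print(b3.shape)  # (1, 3, 3)

# Without the leading axis: shapes (3, 1) and (1, 3) -> (3, 3),
# where b2[i, j] == some_func(a, r1[i], r2[j]).
b2 = some_func(a, r1[:, None], r2[None, :])
print(b2.shape)  # (3, 3)
print(b2.sum())  # 18

# The nested-vmap result is indexed [j, i], i.e. it is the transpose of b2.
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
print(jnp.array_equal(b2, func_mapped(a, r1, r2).T))  # True
```

Broadcasting only works here because `some_func` is built from elementwise operations; for functions that are not written to broadcast, the nested vmap is the more general tool.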