How to use Jax vmap over zipped arguments?

Tags:

python

jax

I have the following example code, which works with Python's built-in map:

import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

list(map(f, zip(xs, ys)))

# returns:
[DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32)]

How can I use jax.vmap instead? The naive attempt is:

jax.vmap(f)(zip(xs, ys))

but this gives:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
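A minimal sketch of why this fails, reusing the xs and ys lists defined above: zip returns a lazy iterator, and JAX does not recognize a zip object as a pytree container, so it is treated as a single opaque leaf rather than as a sequence of (x, y) pairs. A single non-array leaf has no axis 0 to map over, which is consistent with the shape () reported in the error. jax.tree_util.tree_leaves makes the difference visible.

```python
import jax
import jax.numpy as jnp

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

pairs = zip(xs, ys)

# tree_leaves flattens a pytree into its leaves; a zip iterator is not a
# registered container type, so it comes back as one opaque leaf
zip_leaves = jax.tree_util.tree_leaves(pairs)

# A list of tuples is a registered pytree: its leaves are the 8 arrays inside
list_leaves = jax.tree_util.tree_leaves(list(zip(xs, ys)))

print(len(zip_leaves), len(list_leaves))  # 1 8
```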
marius asked Oct 19 '25 02:10


2 Answers

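To use jax.vmap, you do not need to zip your variables. Instead, store each set of inputs as a single array whose leading axis is the batch axis, and pass the arrays together as a tuple; vmap maps over axis 0 of each array in the tuple: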

import jax.numpy as jnp
from jax import vmap

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

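xs = jnp.zeros((4, 3))  # 4 vectors of length 3, stacked along axis 0
ys = jnp.zeros((4, 2))  # 4 vectors of length 2, stacked along axis 0
vmap(f)((xs, ys))       # vmap maps over axis 0 of each array in the tuple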

Output:

DeviceArray([0., 0., 0., 0.], dtype=float32)
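If the data starts out as the Python lists from the question, a minimal sketch (assuming the lists have equal length and the arrays within each list share a shape) is to stack each list with jnp.stack and then pass the pair as a tuple. Each call to f then receives one (x, y) pair, the same pairs zip would produce. The inputs below are illustrative (element i is filled with the value i) so that the per-pair results are distinguishable.

```python
import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

# Illustrative inputs: element i is filled with the value i
xs = [jnp.full(3, float(i)) for i in range(4)]
ys = [jnp.full(2, float(i)) for i in range(4)]

# Stack each list into one array with a leading batch axis of size 4
xs_arr = jnp.stack(xs)  # shape (4, 3)
ys_arr = jnp.stack(ys)  # shape (4, 2)

# vmap maps over axis 0 of every array inside the tuple
batched = jax.vmap(f)((xs_arr, ys_arr))  # values 0, 5, 10, 15

# Reference result from the plain Python map used in the question
looped = jnp.array(list(map(f, zip(xs, ys))))
```

Pair i contributes 3*i + 2*i = 5*i, and the vmapped result matches the looped one element for element.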
I'mahdi answered Oct 20 '25 15:10


vmap maps over every positional argument by default (in_axes=0 for each one), so no zip is needed. Furthermore, it can only map over array axes, not over the elements of lists or tuples: a list or tuple passed to vmap is treated as a container, and vmap maps over the leading axis of each array inside it. So a more canonical way to write your example would be to convert your lists to arrays and do something like this:

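import jax
import jax.numpy as jnp

def g(x, y):
  return x.sum() + y.sum()

# xs and ys are the lists of arrays from the question
xs_arr = jnp.asarray(xs)  # shape (4, 3)
ys_arr = jnp.asarray(ys)  # shape (4, 2)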

jax.vmap(g)(xs_arr, ys_arr)
# DeviceArray([0., 0., 0., 0.], dtype=float32)
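A minimal sketch illustrating both points above, using illustrative inputs where element i is filled with the value i; the helper name total is hypothetical. Passing the list xs itself to vmap does not iterate over its four arrays: vmap slices axis 0 of each of the four length-3 arrays, so the batch size is 3 and each call sees a list of four scalars. The last line shows in_axes, where None keeps an argument unbatched.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    return x.sum() + y.sum()

# Illustrative inputs: element i is filled with the value i
xs = [jnp.full(3, float(i)) for i in range(4)]
ys = [jnp.full(2, float(i)) for i in range(4)]

# Hypothetical helper: on each call it receives a list of 4 scalars,
# one taken from the same position of each array in xs
def total(list_of_scalars):
    return sum(list_of_scalars)

# Batch size comes from the arrays' length (3), not the list's length (4)
per_position = jax.vmap(total)(xs)  # each entry is 0 + 1 + 2 + 3 = 6

# Converting to arrays first gives one result per (x, y) pair
xs_arr = jnp.asarray(xs)  # shape (4, 3)
ys_arr = jnp.asarray(ys)  # shape (4, 2)
per_pair = jax.vmap(g)(xs_arr, ys_arr)  # same as in_axes=(0, 0)

# in_axes=None leaves y unbatched: every row of xs_arr is paired with the same y
fixed_y = jax.vmap(g, in_axes=(0, None))(xs_arr, jnp.ones(2))
```

Here per_position has shape (3,) because the batch size was taken from the arrays inside the list, while per_pair has one entry per (x, y) pair, and fixed_y adds the same y.sum() of 2 to each 3*i.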
jakevdp answered Oct 20 '25 17:10