
Implementing if-then-elif-then-else in JAX

Tags: python, cudnn, jax

I'm just starting to use JAX, and I wonder what the right way is to implement if-then-elif-then-else in JAX/Python. For example, given the input arrays n = [5, 4, 3, 2] and k = [3, 3, 3, 3], I need to implement the following pseudo-code:

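import jax.numpy as jnp

def n_choose_k_safe(n, k):
    # JAX arrays are immutable, so use r.at[i].set(...) instead of r[i] = ...
    r = jnp.zeros(len(n))
    for i in range(len(n)):
        if n[i] < k[i]:
            r = r.at[i].set(0)
        elif n[i] == k[i]:
            r = r.at[i].set(1)
        else:
            # func_nchoosek is my own binomial-coefficient function (not shown)
            r = r.at[i].set(func_nchoosek(n[i], k[i]))
    return r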

There are so many choices (vmap, lax.select, jnp.where, lax.cond, lax.fori_loop, etc.) that it is hard to decide which combination of these utilities to use. By the way, k can be a scalar, if that makes it simpler.
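For comparison, here is a minimal sketch of how two of the tools listed above can express the same three-way branch: nested jnp.where for whole arrays, and nested jax.lax.cond for a single scalar pair, mapped over the elements with jax.vmap. The func_nchoosek below is a hypothetical stand-in (binomial coefficient via jax.scipy.special.gammaln), not the asker's function. Note that when lax.cond's predicate is batched by vmap, JAX converts it to a select, so both branches are evaluated for every element anyway.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical stand-in for the asker's function (not shown in the question):
    # C(n, k) via log-gamma, rounded because exp(...) is only approximately an integer.
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))


def n_choose_k_where(n, k):
    # Elementwise: the inner where handles the elif, the outer where the if.
    # All branches are computed for every element; where only picks the results.
    return jnp.where(n < k, 0.0, jnp.where(n == k, 1.0, func_nchoosek(n, k)))


def n_choose_k_scalar(n, k):
    # Scalar version with nested lax.cond; the branches must return matching
    # shapes and dtypes, hence the explicit float32 constants.
    return jax.lax.cond(
        n < k,
        lambda: jnp.float32(0),
        lambda: jax.lax.cond(n == k,
                             lambda: jnp.float32(1),
                             lambda: func_nchoosek(n, k)),
    )


n = jnp.array([5, 4, 3, 2])
k = jnp.array([3, 3, 3, 3])
print(n_choose_k_where(n, k))
# vmap maps the scalar version over n; in_axes=(0, None) keeps k a scalar.
print(jax.vmap(n_choose_k_scalar, in_axes=(0, None))(n, 3))
```

The jnp.where form is usually the simpler choice for elementwise branching; lax.cond only skips work when its predicate is a true (unbatched) scalar.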

asked Feb 04 '26 04:02 by Terry


1 Answer

There's a slightly more compact way to express the solution in Valentin's answer, using jax.numpy.select:

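import jax.numpy as jnp

def n_choose_k_safe(n, k):
  # The first matching condition wins; entries matching neither get `default`.
  return jnp.select(condlist=[n > k, n == k],
                    choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
                    default=0)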

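For input arrays of length 4, this should return the same result as your original code, assuming func_nchoosek is compatible with jax.vmap. Note that func_nchoosek is evaluated for every element, including those where n < k; jnp.select only decides which result to keep. Using jnp.vectorize here in place of vmap also makes the function work with a scalar k, because vectorize broadcasts its inputs before mapping, so you don't have to set the in_axes argument manually.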
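As a quick check, here is a runnable sketch of the answer's approach on the question's inputs, with both an array k and a scalar k. The func_nchoosek here is a hypothetical stand-in (binomial coefficient via jax.scipy.special.gammaln), since the asker's version isn't shown; the jax.jit decorator is an added assumption to show that the select-based version traces without Python control flow.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def func_nchoosek(n, k):
    # Hypothetical scalar binomial coefficient (not the asker's actual function).
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))


@jax.jit
def n_choose_k_safe(n, k):
    # Same expression as in the answer; jit works because there is no Python if.
    return jnp.select(condlist=[n > k, n == k],
                      choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
                      default=0)


n = jnp.array([5, 4, 3, 2])
print(n_choose_k_safe(n, jnp.array([3, 3, 3, 3])))  # array k
print(n_choose_k_safe(n, 3))                        # scalar k, broadcast by vectorize
```

Both calls should give the binomial coefficients C(5,3)=10 and C(4,3)=4 for the first two entries, 1 where n == k, and 0 where n < k.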

answered Feb 05 '26 16:02 by jakevdp