I have the following NumPy function, seen below, that I'm trying to optimize with JAX, but for whatever reason it's slower.
Could someone point out what I can do to improve the performance here? I suspect it has to do with the list comprehension that produces Cg_new, but breaking it apart doesn't yield any further performance gains in JAX.
import numpy as np
def testFunction_numpy(C, Mi, C_new, Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = np.zeros((1, len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)
    Wg_new = np.dot(invertCsensor_new, Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
    return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
Here's the JAX equivalent:
import jax.numpy as jnp
import numpy as np
import jax
def testFunction_JAX(C, Mi, C_new, Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = jnp.zeros((1, len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)
    Wg_new = jnp.dot(invertCsensor_new, Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
    return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
C_jax = jnp.asarray(C)
Mi_jax = jnp.asarray(Mi)
C_new_jax = jnp.asarray(C_new)
Mi_new_jax = jnp.asarray(Mi_new)
jitter = jax.jit(testFunction_JAX)
%timeit jitter(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
#1 loop, best of 3: 4.96 ms per loop
For general considerations on benchmark comparisons between JAX and NumPy, see https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
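One point from that FAQ is worth applying here: JAX dispatches work asynchronously, and the first call to a jitted function includes compilation time. A fairer timing pattern looks something like the following sketch, using the names from your snippet:

# Warm-up call: triggers JIT compilation so it isn't counted in the timing.
jitter(C_jax, Mi_jax, C_new_jax, Mi_new_jax)

# Block on one of the returned arrays so %timeit measures actual execution,
# not just the asynchronous dispatch of the computation.
%timeit jitter(C_jax, Mi_jax, C_new_jax, Mi_new_jax)[2].block_until_ready()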
As for your particular code: when JAX jit compilation encounters Python control flow, including list comprehensions, it effectively flattens the loop and stages the full sequence of operations. This can lead to slow jit compile times and suboptimal code. Fortunately, the list comprehension in your function is readily expressed in terms of native numpy broadcasting. Additionally, there are two other improvements you can make:
- There is no need to pre-allocate Wg_new and Cg_new with zeros before computing them; those arrays are immediately overwritten.
- When computing dot(inv(A), B), it is much more efficient and numerically precise to use np.linalg.solve rather than explicitly computing the inverse.

Making these three improvements to both the numpy and JAX versions results in the following:
def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
    Wg_new = np.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
    Wg_new = jnp.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new
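Before timing, it's worth a quick sanity check that the broadcast rewrite matches the original list-comprehension result (using the numpy arrays defined in the question):

_, _, Wg_v1, Cg_v1 = testFunction_numpy(C, Mi, C_new, Mi_new)
_, _, Wg_v2, Cg_v2 = testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# Expect True True: solve() and the broadcast sum agree with the
# inv()-based version to within floating-point tolerance.
print(np.allclose(Wg_v1, Wg_v2), np.allclose(Cg_v1, Cg_v2))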
%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop
Both functions are a fair bit faster than they were previously due to the improved implementation. You'll notice, however, that JAX is still slower than numpy here; this is somewhat to be expected because for a function of this level of simplicity, JAX and numpy are both generating effectively the same short series of BLAS and LAPACK calls executed on a CPU architecture. There's simply not much room for improvement over numpy's reference implementation, and with such small arrays JAX's overhead is apparent.
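As an aside: if you ever have a per-column computation that doesn't reduce to a simple broadcast, jax.vmap lets you express it without a Python-level loop, so jit traces the operation once rather than once per column. A minimal sketch of the list comprehension rewritten this way (not needed here, since the broadcast form above is simpler):

# Map a per-column dot product over axis 1 of both arguments;
# this is equivalent to the original list comprehension but traced once.
Wg_new = jnp.linalg.solve(C_new_jax, Mi_new_jax)
per_column = jax.vmap(lambda mi, wg: jnp.dot(-0.5 * mi.conj(), wg), in_axes=(1, 1))
Cg_new = per_column(Mi_new_jax, Wg_new)  # shape (8,), matches the broadcast version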