How can I generate random numbers between 0 and 1 in JAX?
Basically, I am looking to replicate the following NumPy call in JAX:
np.random.random(1000)
The equivalent in JAX would be:
from jax import random
key = random.PRNGKey(758493) # Random seed is explicit in JAX
random.uniform(key, shape=(1000,))
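A minimal side-by-side sketch, assuming both NumPy and JAX are installed, that puts the two calls next to each other and checks that each returns 1000 floats in the half-open interval [0, 1). Note that JAX defaults to float32 (unless 64-bit mode is enabled), whereas NumPy returns float64.

```python
import numpy as np
from jax import random

# NumPy: samples come from a hidden, global generator state
np_samples = np.random.random(1000)

# JAX: the PRNG key (seed) is passed explicitly to every sampling call
key = random.PRNGKey(758493)
jax_samples = random.uniform(key, shape=(1000,))

# Both are 1-D arrays of 1000 values in [0, 1)
print(np_samples.shape, jax_samples.shape)
print(bool((jax_samples >= 0.0).all()), bool((jax_samples < 1.0).all()))
```

If you need a different range, `random.uniform` also accepts `minval` and `maxval` keyword arguments (defaulting to 0.0 and 1.0).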
For more information, see the documentation of the jax.random module.
Also be aware that JAX's random number generator does not maintain any global state: every sampling function takes a key explicitly, and calling it twice with the same key returns the same numbers. You'll therefore need to think about it a bit differently than you may be accustomed to in NumPy. For more background on this, see JAX Sharp Bits: Random Numbers.
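A short sketch of what "no global state" means in practice: reusing a key reproduces the same draw, so to get fresh numbers you split the key with `random.split` and use the new subkey. The variable names here are illustrative only.

```python
from jax import random

key = random.PRNGKey(758493)

# Reusing the same key gives identical samples: nothing advances between calls
a = random.uniform(key, shape=(3,))
b = random.uniform(key, shape=(3,))

# To get new samples, split the key and consume the subkey;
# keep the returned `key` around for future splits
key, subkey = random.split(key)
c = random.uniform(subkey, shape=(3,))

print(a)
print(b)
print(c)
```

The usual pattern is to split once per sampling call (or split into many subkeys at once with `random.split(key, n)`) and never reuse a subkey after it has been consumed.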