
How to generate random numbers between 0 and 1 in jax?

Tags:

jax

How can I generate random numbers between 0 and 1 in JAX? I am looking to replicate the following NumPy call in JAX:

np.random.random(1000)
Bunny Rabbit asked Oct 20 '25 17:10


1 Answer

The equivalent in JAX would be:

from jax import random
key = random.PRNGKey(758493)  # Random seed is explicit in JAX
random.uniform(key, shape=(1000,))  # 1000 samples drawn uniformly from [0, 1)

For more information, see the documentation of the jax.random module.

Also be aware that JAX's random number generator does not maintain any sort of global state, so you'll need to think about it a bit differently than you may be accustomed to in NumPy. For more background on this, see JAX Sharp Bits: Random Numbers.
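To illustrate what the lack of global state means in practice, here is a minimal sketch. The seed value and the variable names are arbitrary choices for the example, not anything prescribed by JAX. It shows that reusing a key reproduces the same samples, while keys obtained from random.split give independent draws.

```python
from jax import random

key = random.PRNGKey(0)  # arbitrary seed, chosen only for illustration

# Reusing the same key reproduces exactly the same samples.
a = random.uniform(key, shape=(3,))
b = random.uniform(key, shape=(3,))

# Splitting produces fresh keys; each one should be consumed only once.
key, subkey1, subkey2 = random.split(key, 3)
c = random.uniform(subkey1, shape=(3,))
d = random.uniform(subkey2, shape=(3,))

print(bool((a == b).all()))  # compares draws made with the same key
print(bool((c == d).all()))  # compares draws made with different subkeys
```

A common pattern is therefore to split the key before every call that needs randomness and to keep the returned key for the next split, rather than passing the same key to several sampling calls.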

jakevdp answered Oct 24 '25 16:10


