I am quite new to JAX and am trying to use it for some optimization work. I tried a CPU-only version of JAX, and it worked well. As expected, though, the speed is not impressive, so I am planning to move to the GPU version (my laptop has an RTX graphics card). However, I have spent a whole day trying to install it from the Anaconda Prompt with different installation commands, and each attempt failed with various errors. The commands I tried include:
conda create -n GPU_folder python==3.11
conda activate GPU_folder
conda install nvidia/label/cuda-12.1.0::cuda
conda install -c anaconda cudnn=8.9
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
However, I am very unsure whether the versions of jax, jaxlib, etc. are compatible with each other. I would be very happy to have your advice on reliable installation commands (frustratingly, it is really difficult for me to find the correct ones). I am planning to use Spyder with JAX, so I would also appreciate any suggestions on whether that will work. Thank you very much!
The cuda12_pip extra for JAX has been removed as of JAX v0.6.0; the current GPU installation instructions can be found at https://docs.jax.dev/en/latest/installation.html#nvidia-gpu.
In particular, this is the way to install the current JAX release with CUDA 12 support:
pip install --upgrade "jax[cuda12]"
Once this is installed, you can check that you're correctly connected to a GPU device like this:
import jax
print(jax.devices())
If the output shows CUDA devices (e.g. [CudaDevice(id=0)]) then JAX is connected to your GPU device.
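For a slightly more thorough check than printing the device list, here is a minimal sketch that reports the default backend and runs a small computation. The helper name describe_backend is hypothetical (it is not part of JAX), and the 1000x1000 matrix size is an arbitrary choice for illustration; the JAX calls used are jax.devices(), jax.default_backend(), and block_until_ready().

```python
import importlib.util


def describe_backend():
    """Hypothetical helper: summarize which backend JAX is using."""
    # Avoid an ImportError if JAX is missing from the active environment.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"

    import jax
    import jax.numpy as jnp

    backend = jax.default_backend()  # a platform name such as "cpu" or "gpu"
    devices = jax.devices()          # devices available on that backend

    # Run a small matrix product and wait for it to finish, which
    # confirms the backend can actually execute work.
    x = jnp.ones((1000, 1000))
    (x @ x).block_until_ready()

    return f"backend={backend}, devices={devices}"


print(describe_backend())
```

If the reported backend is cpu even though the CUDA install appeared to succeed, JAX has fallen back to the CPU, and the installation is worth revisiting.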