
GPU-Supported JAX Installation [closed]

I am quite new to JAX and am trying to use it for some optimization work. The CPU-only version works well, but, as expected, its speed is not impressive, so I am planning to move to the GPU version (my laptop has an RTX graphics card). However, I have spent a whole day trying to install it from the Anaconda Prompt with different installation commands, and each attempt failed with various errors. Among others, I tried the following commands:

conda create -n GPU_folder python==3.11
conda activate GPU_folder
conda install nvidia/label/cuda-12.1.0::cuda
conda install -c anaconda cudnn=8.9
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

However, I am very unsure whether the versions of jax, jaxlib, etc. are compatible with each other. I would be very happy to have your advice on reliable installation commands (depressingly, it is really difficult for me to find the correct ones). I am planning to use Spyder with JAX, so I would also appreciate any suggestions on whether that will work. Thank you very much!

asked Feb 18 '26 11:02 by Newbee


1 Answer

The cuda12_pip extra for JAX has been removed as of JAX v0.6.0; the current GPU installation instructions can be found at https://docs.jax.dev/en/latest/installation.html#nvidia-gpu.

In particular, this is the way to install the current JAX release with CUDA 12 support:

pip install --upgrade "jax[cuda12]"

Once this is installed, you can check that you're correctly connected to a GPU device like this:

import jax
print(jax.devices())

If the output shows CUDA devices (e.g. [CudaDevice(id=0)]) then JAX is connected to your GPU device.
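
If you want a slightly more detailed check, something like the sketch below may help. It is only an illustration: the helper name `describe_jax_backend` is made up for this example, and it assumes nothing beyond the public `jax.__version__`, `jax.default_backend()` and `jax.devices()` calls. It also copes with the case where jax is not importable in the active environment, which can happen if Spyder is using a different environment than the one you installed into.

```python
import importlib.util


def describe_jax_backend():
    """Return a one-line status string about JAX and its default backend.

    Hypothetical helper for illustration only; it is not part of JAX.
    """
    # Avoid an ImportError if jax is missing from the active environment.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"

    import jax

    # default_backend() reports the platform JAX will use by default,
    # for example "cpu" or "gpu".
    backend = jax.default_backend()
    devices = jax.devices()
    return f"jax {jax.__version__}: backend={backend}, devices={devices}"


print(describe_jax_backend())
```

If the reported backend is "cpu" even though the CUDA extra was installed, the GPU was probably not picked up (for example because of a driver problem, or because a CPU-only jaxlib is being used), and the linked installation page is the place to look next.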

answered Feb 19 '26 23:02 by jakevdp


