 

Execute function specifically on CPU in Jax

I have a function that instantiates a huge array and does other things. I am running my code on TPUs, so my memory is limited.

How can I execute my function specifically on the CPU?

If I do:

y = jax.device_put(my_function(), device=jax.devices("cpu")[0])

I assume that my_function() is first executed on the TPU and only the result is then moved to the CPU, which gives me a memory error.

Using jax.config.update('jax_platform_name', 'cpu') at the beginning of my code also seems to have no effect.

Also, please note that I can't modify my_function().

Thanks!

Valentin Macé asked Oct 17 '25 05:10



2 Answers

To directly specify the device on which a function should be executed, use the device argument of jax.jit. For example (using a GPU runtime because it's the accelerator I have access to at the moment):

import jax

gpu_device = jax.devices('gpu')[0]
cpu_device = jax.devices('cpu')[0]

def my_function(x):
  return x.sum()

x = jax.numpy.arange(10)

x_gpu = jax.jit(my_function, device=gpu_device)(x)
print(x_gpu.device())
# gpu:0

x_cpu = jax.jit(my_function, device=cpu_device)(x)
print(x_cpu.device())
# TFRT_CPU_0
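A related option, sketched here under assumptions rather than taken from the answer: when the function takes array inputs, you can commit those inputs to the CPU with jax.device_put, because a jitted computation on committed arrays runs on the device that holds them. This avoids passing device to jax.jit. The input is built with NumPy so that it is never allocated on the accelerator first. Note that this does not cover the asker's case of a function with no inputs.

```python
import jax
import numpy as np

cpu_device = jax.devices("cpu")[0]

def my_function(x):
    # Same toy function as above: reduce the input to a scalar.
    return x.sum()

# device_put with an explicit device "commits" the array to that device.
# Starting from a NumPy array means the data goes straight to the CPU.
x_cpu = jax.device_put(np.arange(10), cpu_device)

# A jitted call on committed inputs runs where those inputs live,
# so no device argument is given to jax.jit here.
result = jax.jit(my_function)(x_cpu)
print(result.devices())
```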

This can also be controlled with the jax.default_device context manager around the call site:

with jax.default_device(cpu_device):
  print(jax.jit(my_function)(x).device())
  # TFRT_CPU_0

with jax.default_device(gpu_device):
  print(jax.jit(my_function)(x).device())
  # gpu:0
jakevdp answered Oct 18 '25 18:10



I'm going to make a guess here. I can't run this myself either, so you may have to fiddle with it:

with jax.default_device(jax.devices("cpu")[0]):
    y = my_function()

See the docs here and here.
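A self-contained version of the snippet above that can be run to check placement. It assumes a hypothetical stand-in for my_function (the real one isn't shown) that takes no arguments and builds a large array internally. Because jnp.ones respects the default device, the intermediate array is also created on the CPU, not only the result. Array.devices() reports where the result lives.

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

def my_function():
    # Hypothetical stand-in for the asker's function: no arguments,
    # and a large array is allocated internally.
    big = jnp.ones((1000, 1000), dtype=jnp.float32)
    return big.sum()

# Inside this block, newly created (uncommitted) arrays are placed on the CPU
# instead of the default accelerator.
with jax.default_device(cpu_device):
    y = my_function()

print(y.devices())
```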

joel answered Oct 18 '25 17:10



