 

What is the difference between JAX, Trax, and TensorRT, in simple terms?

I have been using TensorRT and TensorFlow-TRT to accelerate the inference of my DL algorithms.

Then I heard about:

  • JAX https://github.com/google/jax
  • Trax https://github.com/google/trax

Both seem to accelerate DL, but I am having a hard time understanding them. Can anyone explain them in simple terms?

Aizzaac asked Mar 03 '23 16:03


1 Answer

Trax is a deep learning framework created by Google and used extensively by the Google Brain team. It is an alternative to TensorFlow and PyTorch for implementing off-the-shelf, state-of-the-art deep learning models such as Transformers and BERT, mainly in the Natural Language Processing field.

Trax is built on top of TensorFlow and JAX. JAX can be thought of as an enhanced, accelerated version of NumPy: it exposes a NumPy-like API (`jax.numpy`). The important difference between JAX and NumPy is that JAX compiles your code with XLA (Accelerated Linear Algebra), which lets the same NumPy-style code run on GPUs and TPUs instead of only on the CPU as plain NumPy does, thus speeding up computation.
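To make the "NumPy code compiled by XLA" idea concrete, here is a minimal sketch, assuming a recent JAX install (`pip install jax`). The function `predict` and the array shapes are purely illustrative. `jax.jit` traces the function and compiles it with XLA for the default device: with a CPU-only install it runs on the CPU, and with a GPU or TPU build of JAX the same code is placed on the accelerator without changes.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Illustrative layer: written with jax.numpy exactly as you would with NumPy.
def predict(w, b, x):
    return jnp.tanh(x @ w + b)

# jax.jit compiles predict with XLA for whatever device JAX finds (CPU/GPU/TPU).
predict_jit = jax.jit(predict)

# JAX uses explicit random keys instead of NumPy's global random state.
key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
w = jax.random.normal(k_w, (3, 2))
b = jnp.zeros(2)
x = jax.random.normal(k_x, (4, 3))

out = predict_jit(w, b, x)

# Same computation in plain NumPy on the CPU, for comparison.
ref = np.tanh(np.asarray(x) @ np.asarray(w) + np.asarray(b))

print(jax.devices())  # lists the device(s) XLA compiled for
print(out.shape)
```

The first call to `predict_jit` pays a one-time compilation cost; later calls with the same input shapes reuse the compiled XLA program, which is where the speedup comes from.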

Timbus Calin answered May 24 '23 02:05