JAX

[thumbnail: jaxlogo]

JAX stands for "Just Another XLA", where XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra operations developed by Google.

JAX provides a familiar NumPy-like API for array operations and supports automatic differentiation for arbitrary code, including control flow statements like loops and conditionals. This makes it easy to write and optimize complex numerical computations, including deep learning models, with high performance on CPU, GPU, and TPU hardware.
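
A minimal sketch of both points, using only the public jax and jax.numpy APIs: jax.grad differentiates a plain Python function even when it branches on its input.

  import jax
  import jax.numpy as jnp

  def piecewise(x):
      # Python-level control flow; jax.grad evaluates with concrete values,
      # so an ordinary if/else works here (it would not under jax.jit)
      if x > 0:
          return jnp.sin(x) * x
      return x ** 2

  dpiecewise = jax.grad(piecewise)
  print(dpiecewise(1.5))   # d/dx [x*sin(x)] at x = 1.5
  print(dpiecewise(-2.0))  # d/dx [x**2] at x = -2.0 -> -4.0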

JAX also encourages a functional programming paradigm, letting users compose functions in a pure, side-effect-free way. This can make it easier to reason about a program's behavior and to reuse code. As the official documentation puts it, JAX is "NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research."
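
As a hedged illustration of that compositional style (the names predict and loss below are made up for the example), pure functions can be freely wrapped in JAX's transformations:

  import jax
  import jax.numpy as jnp

  def predict(w, x):
      # pure: the output depends only on the inputs, with no side effects
      return jnp.tanh(x @ w)

  def loss(w, x, y):
      return jnp.mean((predict(w, x) - y) ** 2)

  # transformations compose: differentiate, then compile the result
  grad_loss = jax.jit(jax.grad(loss))

  key = jax.random.PRNGKey(0)
  w = jax.random.normal(key, (3,))
  x = jax.random.normal(key, (8, 3))
  y = jnp.ones(8)
  print(grad_loss(w, x, y))  # gradient with respect to w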

JAX limitations as of March 2023

JAX does not look stable yet, and its community seems to be moving more slowly than those of other well-known deep learning frameworks.

  • "No kernel image is available for execution on the device" errors in Google JAX. This CUDA error typically means the installed jaxlib build was not compiled for the GPU's compute capability.
  • Compatibility issues - Google JAX is a relatively new library, so it may not be compatible with all existing libraries. This can cause friction when using JAX alongside other array libraries such as NumPy.
  • Memory leak issues - one of the most commonly reported problems with Google JAX. A leak occurs when the program keeps holding memory after it is no longer needed, which can drive memory usage up and eventually crash the program.
  • Performance issues - another common problem, showing up as slow execution times or poor GPU utilization. See the sketch after this list for common mitigations.
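
A sketch of common first-line mitigations for the memory and performance items above. The XLA_PYTHON_CLIENT_* variables are documented JAX GPU-memory controls (JAX preallocates a large fraction of GPU memory by default, which is often mistaken for a leak), and block_until_ready() accounts for JAX's asynchronous dispatch when timing code:

  import os
  # must be set before jax is imported
  os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
  # alternatively, cap the preallocated fraction:
  # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

  import jax
  import jax.numpy as jnp

  @jax.jit  # compiling with jit is usually the first fix for slow execution
  def step(x):
      return (x @ x).sum()

  x = jnp.ones((1024, 1024))
  step(x).block_until_ready()  # block before timing to measure real runtime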

JAX requirements

  • JAX 0.4.6 requires CUDA 11.8 for Hopper and Ada Lovelace series GPUs[1]. A quick way to verify the setup is shown below.
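
A quick sanity check, assuming jax and a matching CUDA-enabled jaxlib wheel are already installed, that the GPU backend actually loaded:

  import jax
  print(jax.__version__)        # e.g. 0.4.6
  print(jax.default_backend())  # expect 'gpu' on a working CUDA 11.8 install
  print(jax.devices())          # the visible accelerator devices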

JAX features[2]

  • As of March 2023, JAX has no dataset/dataloader API, nor standard datasets such as MNIST, so you will have to use either TF or PyTorch for these tasks or implement everything yourself (see the sketch after this list).
  • JAX is very TPU friendly and has built-in support for multiple devices.
  • It is still in the 0.x versions, so the API might change.
  • Functional programming can be annoying for beginners.
  • Apart from the TPU, there are few real advantages over PyTorch (or TF).
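
A minimal sketch of the first point: the usual workaround is to load data with another framework and hand plain NumPy arrays to JAX. This example assumes the tensorflow_datasets package is installed:

  import jax.numpy as jnp
  import tensorflow_datasets as tfds

  # batch_size=-1 loads the whole split; tfds.as_numpy yields NumPy arrays
  mnist = tfds.as_numpy(tfds.load("mnist", split="train", batch_size=-1))
  images = jnp.asarray(mnist["image"]) / 255.0  # (60000, 28, 28, 1), scaled to [0, 1]
  labels = jnp.asarray(mnist["label"])          # (60000,)
  print(images.shape, labels.shape)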

References