JAX

From HPCWIKI
Revision as of 17:10, 25 March 2023 by Admin

JAX brings together Autograd-style automatic differentiation and XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra operations developed by Google.

JAX provides a familiar NumPy-like API for array operations and supports automatic differentiation for arbitrary code, including control flow statements like loops and conditionals. This makes it easy to write and optimize complex numerical computations, including deep learning models, with high performance on CPU, GPU, and TPU hardware.
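As a minimal sketch of the NumPy-like API and differentiation through Python control flow (assuming JAX is installed), jax.grad traces straight through an ordinary Python loop:

```python
import jax

def poly(x):
    # Ordinary Python control flow: JAX traces through this loop,
    # so the gradient of the whole expression is computed correctly.
    total = 0.0
    for k in range(1, 4):
        total = total + x ** k   # builds x + x^2 + x^3
    return total

# grad returns a new function computing d(poly)/dx
dpoly = jax.grad(poly)

# d/dx (x + x^2 + x^3) = 1 + 2x + 3x^2; at x = 2.0 this is 17.0
print(dpoly(2.0))
```

The same function can also be jit-compiled or vectorized without modification, since the transformations compose.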

JAX also provides a functional programming paradigm, enabling users to compose functions in a pure, side-effect-free way. This can make it easier to reason about the behavior of a program and to reuse code.
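For example, a pure function (output depends only on its inputs, with no side effects) can be freely composed with JAX's transformations; the sketch below (assuming JAX is installed) chains jit and vmap around a hypothetical normalize helper:

```python
import jax
import jax.numpy as jnp

# Pure function: no global state, no in-place mutation,
# so JAX can safely trace, compile, and transform it.
def normalize(x):
    return (x - jnp.mean(x)) / jnp.std(x)

# Transformations compose: vectorize over a batch, then jit-compile
batched = jax.jit(jax.vmap(normalize))

out = batched(jnp.arange(12.0).reshape(3, 4))  # normalizes each row
```

Because normalize is side-effect-free, reasoning about `batched` reduces to reasoning about normalize applied row by row.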


JAX features and limitations

  • As of March 2023, JAX has no dataset/dataloader API and no standard datasets such as MNIST. You will therefore have to use TensorFlow or PyTorch for these tasks, or implement them yourself.
  • JAX is very TPU-friendly and has built-in support for multiple devices.
  • It is still at version 0.x, so the API may change.
  • Functional programming can be annoying for beginners.
  • Apart from TPU support, there are few real advantages over PyTorch (or TensorFlow).
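The multi-device support mentioned above can be sketched with jax.pmap, which replicates a function across all local devices (on a CPU-only machine this is typically a single device, so the example degenerates gracefully):

```python
import jax
import jax.numpy as jnp

# Enumerate available accelerators (CPU, GPU, or TPU cores)
print(jax.devices())

# pmap runs the function once per device; the leading axis of the
# input must match the local device count.
n = jax.local_device_count()
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

doubled = jax.pmap(lambda x: x * 2.0)(xs)  # shape (n, 4)
```

On a TPU pod slice the same code shards the work across all cores with no further changes.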