Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators such as GPUs and TPUs.

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
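
As a minimal sketch of what that combination looks like in practice, the snippet below differentiates a small loss function with `jax.grad` and compiles the result with `jax.jit`. The function `loss` and the toy data are illustrative assumptions, not taken from any library in this list.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model, written with the NumPy-like jax.numpy API.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (w) by default;
# jax.jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# Toy data, chosen only for illustration.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient of the loss with respect to w, shape (2,)
```

The same code runs on CPU, GPU, or TPU, depending on which backend the installed JAX build detects.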



  • Neural Network Libraries
    • Flax - Centered on flexibility and clarity.
    • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax - Has an object oriented design similar to PyTorch.
    • Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    • Jraph - Lightweight graph neural network library.
    • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
    • HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    • Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
  • NumPyro - Probabilistic programming based on the Pyro library.
  • Chex - Utilities to write and test reliable JAX code.
  • Optax - Gradient processing and optimization library.
  • RLax - Library for implementing reinforcement learning agents.
  • JAX, M.D. - Accelerated, differentiable molecular dynamics.
  • Coax - Turn RL papers into code, the easy way.
  • Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
  • cvxpylayers - Construct differentiable convex optimization layers.
  • TensorLy - Tensor learning made simple.
  • NetKet - Machine Learning toolbox for Quantum Physics.

New Libraries

This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.

  • Neural Network Libraries
    • FedJAX - Federated learning in JAX, built on Optax and Haiku.
    • Equivariant MLP - Construct equivariant neural network layers.
    • jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
    • Parallax - Immutable Torch Modules for JAX.
  • jax-unirep - Library implementing the UniRep model for protein machine learning applications.
  • jax-flows - Normalizing flows in JAX.
  • sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
  • jax-cosmo - Differentiable cosmology library.
  • efax - Exponential Families in JAX.
  • mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax - Image augmentations and transformations.
  • FlaxVision - Flax version of TorchVision.
  • Oryx - Probabilistic programming language based on program transformations.
  • Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
  • delta PV - A photovoltaic simulator with automatic differentiation.
  • jaxlie - Lie theory library for rigid body transformations and optimization.
  • BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
  • flaxmodels - Pretrained models for JAX/Flax.
  • CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing.
  • exojax - Auto-differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX.
  • JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
  • PIX - Image processing library in JAX, for JAX.
  • bayex - Bayesian Optimization powered by JAX.
  • JaxDF - Framework for differentiable simulators with arbitrary discretizations.
  • tree-math - Convert functions that operate on arrays into functions that operate on PyTrees.
  • jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX.
  • PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
  • EvoJAX - Hardware-accelerated neuroevolution.
  • evosax - JAX-based evolution strategies.
  • SymJAX - Symbolic CPU/GPU/TPU programming.
  • mcx - Express & compile probabilistic programs for performant inference.
  • Einshape - DSL-based reshaping library for JAX and other frameworks.
  • ALX - Open-source library for distributed matrix factorization using Alternating Least Squares. For more information, see "ALX: Large Scale Matrix Factorization on TPUs".
  • Diffrax - Numerical differential equation solvers in JAX.
  • tinygp - The tiniest of Gaussian process libraries in JAX.
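
Several entries above (for example Equinox and tree-math) are built around PyTrees, JAX's term for nested containers of arrays. The sketch below uses only core `jax.tree_util` functions to show the idea; the `params` structure is a made-up example and does not reflect the API of any library listed here.

```python
import jax
import jax.numpy as jnp

# A PyTree is a nested container (dict, list, tuple, ...) whose leaves are arrays.
# This parameter layout is hypothetical, chosen only for illustration.
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}

# tree_map applies a function to every leaf and preserves the container structure.
halved = jax.tree_util.tree_map(lambda leaf: leaf * 0.5, params)

# tree_leaves flattens the tree into a list of its leaves.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
```

Because transformations such as `jax.grad` and `jax.jit` accept PyTrees as arguments, a whole set of model parameters can be passed around as a single value.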

Models and Projects





  • Reformer - Implementation of the Reformer (efficient transformer) architecture.



Papers

This section contains papers focused on JAX (e.g., JAX-based library whitepapers and research on JAX). Papers implemented in JAX are listed in the Models and Projects section.

Tutorials and Blog Posts



Contributions welcome! Read the contribution guidelines first.