# Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
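To make the opening description concrete, here is a minimal sketch of the three pieces working together: a NumPy-like array API (`jax.numpy`), automatic differentiation (`jax.grad`), and XLA compilation (`jax.jit`). The `loss` function and the toy data are hypothetical and only serve as illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a linear model (hypothetical example function)."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (w);
# jax.jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# Toy data, made up for this sketch.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient array with the same shape as w
```

The same code runs unchanged on CPU, GPU, or TPU, depending on which backend the installed JAX build finds.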

## Contents

## Libraries

- Neural Network Libraries
    - Flax - Centered on flexibility and clarity.
    - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    - Objax - Has an object oriented design similar to PyTorch.
    - Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    - Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    - Jraph - Lightweight graph neural network library.
    - Neural Tangents - High-level API for specifying neural networks of both finite and *infinite* width.
    - HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library.
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.
- Coax - Turn RL papers into code, the easy way.
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers - Construct differentiable convex optimization layers.
- TensorLy - Tensor learning made simple.
- NetKet - Machine Learning toolbox for Quantum Physics.

### New Libraries

This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.

- Neural Network Libraries
    - FedJAX - Federated learning in JAX, built on Optax and Haiku.
    - Equivariant MLP - Construct equivariant neural network layers.
    - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
    - Parallax - Immutable Torch Modules for JAX.
- jax-unirep - Library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - `scikit-learn` kernel matrices using JAX.
- jax-cosmo - Differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- imax - Image augmentations and transformations.
- FlaxVision - Flax version of TorchVision.
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation.
- jaxlie - Lie theory library for rigid body transformations and optimization.
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels - Pretrained models for JAX/Flax.
- CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing.
- exojax - Auto-differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
- JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX - PIX is an image processing library in JAX, for JAX.
- bayex - Bayesian Optimization powered by JAX.
- JaxDF - Framework for differentiable simulators with arbitrary discretizations.
- tree-math - Convert functions that operate on arrays into functions that operate on PyTrees.
- jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX.
- PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
- EvoJAX - Hardware-Accelerated Neuroevolution.
- evosax - JAX-Based Evolution Strategies.
- SymJAX - Symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
- Einshape - DSL-based reshaping library for JAX and other frameworks.
- ALX - Open-source library for distributed matrix factorization using Alternating Least Squares, more info in *ALX: Large Scale Matrix Factorization on TPUs*.
- Diffrax - Numerical differential equation solvers in JAX.
- tinygp - The *tiniest* of Gaussian process libraries in JAX.

## Models and Projects

### JAX

- Fourier Feature Networks - Official implementation of *Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains*.
- kalman-jax - Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
- GPJax - Gaussian processes in JAX.
- jaxns - Nested sampling in JAX.
- Amortized Bayesian Optimization - Code related to *Amortized Bayesian Optimization over Discrete Spaces*.
- Accurate Quantized Training - Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
- BNN-HMC - Implementation for the paper *What Are Bayesian Neural Network Posteriors Really Like?*.
- JAX-DFT - One-dimensional density functional theory (DFT) in JAX, with implementation of *Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics*.
- Robust Loss - Reference code for the paper *A General and Adaptive Robust Loss Function*.
- Symbolic Functionals - Demonstration from *Evolving symbolic density functionals*.
- TriMap - Official JAX implementation of *TriMap: Large-scale Dimensionality Reduction Using Triplets*.

### Flax

- Performer - Flax implementation of the Performer (linear transformer via FAVOR+) architecture.
- JaxNeRF - Implementation of *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis* with multi-device GPU/TPU support.
- mip-NeRF - Official implementation of *Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields*.
- RegNeRF - Official implementation of *RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs*.
- Big Transfer (BiT) - Implementation of *Big Transfer (BiT): General Visual Representation Learning*.
- JAX RL - Implementations of reinforcement learning algorithms.
- gMLP - Implementation of *Pay Attention to MLPs*.
- MLP Mixer - Minimal implementation of *MLP-Mixer: An all-MLP Architecture for Vision*.
- Distributed Shampoo - Implementation of *Second Order Optimization Made Practical*.
- NesT - Official implementation of *Aggregating Nested Transformers*.
- XMC-GAN - Official implementation of *Cross-Modal Contrastive Learning for Text-to-Image Generation*.
- FNet - Official implementation of *FNet: Mixing Tokens with Fourier Transforms*.
- GFSA - Official implementation of *Learning Graph Structure With A Finite-State Automaton Layer*.
- IPA-GNN - Official implementation of *Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks*.
- Flax Models - Collection of models and methods implemented in Flax.
- Protein LM - Implements BERT and autoregressive models for proteins, as described in *Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences* and *ProGen: Language Modeling for Protein Generation*.
- Slot Attention - Reference implementation for *Differentiable Patch Selection for Image Recognition*.
- Vision Transformer - Official implementation of *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*.
- FID computation - Port of mseitzer/pytorch-fid to Flax.
- ARDM - Official implementation of *Autoregressive Diffusion Models*.
- D3PM - Official implementation of *Structured Denoising Diffusion Models in Discrete State-Spaces*.
- Gumbel-max Causal Mechanisms - Code for *Learning Generalized Gumbel-max Causal Mechanisms*, with extra code in GuyLor/gumbel_max_causal_gadgets_part2.
- Latent Programmer - Code for the ICML 2021 paper *Latent Programmer: Discrete Latent Codes for Program Synthesis*.
- SNeRG - Official implementation of *Baking Neural Radiance Fields for Real-Time View Synthesis*.
- Spin-weighted Spherical CNNs - Adaptation of *Spin-Weighted Spherical CNNs*.
- VDVAE - Adaptation of *Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images*, original code at openai/vdvae.
- MUSIQ - Checkpoints and model inference code for the ICCV 2021 paper *MUSIQ: Multi-scale Image Quality Transformer*.
- AQuaDem - Official implementation of *Continuous Control with Action Quantization from Demonstrations*.
- Combiner - Official implementation of *Combiner: Full Attention Transformer with Sparse Computation Cost*.
- Dreamfields - Official implementation of the ICLR 2022 paper *Progressive Distillation for Fast Sampling of Diffusion Models*.
- GIFT - Official implementation of *Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent*.
- Light Field Neural Rendering - Official implementation of *Light Field Neural Rendering*.

### Haiku

- AlphaFold - Implementation of the inference pipeline of AlphaFold v2.0, presented in *Highly accurate protein structure prediction with AlphaFold*.
- Adversarial Robustness - Reference code for *Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples* and *Fixing Data Augmentation to Improve Adversarial Robustness*.
- Bootstrap Your Own Latent - Implementation for the paper *Bootstrap your own latent: A new approach to self-supervised Learning*.
- Gated Linear Networks - GLNs are a family of backpropagation-free neural networks.
- Glassy Dynamics - Open source implementation of the paper *Unveiling the predictive power of static structure in glassy systems*.
- MMV - Code for the models in *Self-Supervised MultiModal Versatile Networks*.
- Normalizer-Free Networks - Official Haiku implementation of *NFNets*.
- NuX - Normalizing flows with JAX.
- OGB-LSC - This repository contains DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).
- Persistent Evolution Strategies - Code used for the paper *Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies*.
- Two Player Auction Learning - JAX implementation of the paper *Auction learning as a two-player game*.
- WikiGraphs - Baseline code to reproduce results in *WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset*.

### Trax

- Reformer - Implementation of the Reformer (efficient transformer) architecture.

## Videos

- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX.
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX's core design, how it's powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - JAX intro presentation in the *Program Transformations for Machine Learning* workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - Presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in *Deep Implicit Layers*.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.

## Papers

This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models and Projects section.

- **Compiling machine learning programs via high-level tracing**. Roy Frostig, Matthew James Johnson, Chris Leary. *MLSys 2018*. - White paper describing an early version of JAX, detailing how computation is traced and compiled.
- **JAX, M.D.: A Framework for Differentiable Physics**. Samuel S. Schoenholz, Ekin D. Cubuk. *NeurIPS 2020*. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- **Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization**. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. *arXiv 2020*. - Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries.
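The last entry above mentions combining JIT and VMAP. As a rough sketch of that general idea (not the paper's code), the snippet below computes per-example gradients with `jax.vmap(jax.grad(...))`, compiles the result with `jax.jit`, and clips each gradient's norm, which is one ingredient of differentially private SGD. The model, the `clip_rows` helper, and the data are hypothetical, and the noise-addition step of DP-SGD is left out.

```python
import jax
import jax.numpy as jnp


def example_loss(w, x, y):
    """Squared error for a single example (hypothetical linear model)."""
    return (jnp.dot(x, w) - y) ** 2


# grad differentiates w.r.t. w; vmap maps over the batch axis of x and y
# while sharing w (in_axes=None); jit compiles the whole function with XLA.
per_example_grads = jax.jit(
    jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))
)


def clip_rows(grads, max_norm):
    """Scale each per-example gradient so its L2 norm is at most max_norm."""
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norms, 1e-12))
    return grads * scale


# Toy batch of three examples, made up for this sketch.
x = jnp.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 1.0, 1.0])
w = jnp.zeros(2)

grads = per_example_grads(w, x, y)        # shape (3, 2): one gradient per example
clipped = clip_rows(grads, max_norm=1.0)  # each row now has L2 norm <= 1
```

Without `vmap`, per-example gradients would typically need a Python loop over the batch; here the batching is expressed as a function transformation and compiled as a single program.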

## Tutorials and Blog Posts

- Using JAX to accelerate our research by David Budden and Matteo Hessel - Describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange - Neural network building blocks from scratch with the basic JAX operators.
- Tutorial: image classification with JAX and Flax Linen by 8bitmp3 - Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
- Plugging Into JAX by Nick Doiron - Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
- Meta-Learning in 50 Lines of JAX by Eric Jang - Introduction to both JAX and Meta-Learning.
- Normalizing Flows in 100 Lines of JAX by Eric Jang - Concise implementation of RealNVP.
- Differentiable Path Tracing on the GPU/TPU by Eric Jang - Tutorial on implementing path tracing.
- Ensemble networks by Mat Kelcey - Ensemble nets are a method of representing an ensemble of models as one single logical model.
- Out of distribution (OOD) detection by Mat Kelcey - Implements different methods for OOD detection.
- Understanding Autodiff with JAX by Srihari Radhakrishna - Understand how autodiff works using JAX.
- From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke - Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
- Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey - Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
- Evolving Neural Networks in JAX by Robert Tjarko Lange - Explores how JAX can power the next generation of scalable neuroevolution algorithms.
- Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz - Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.
- Deterministic ADVI in JAX by Martin Ingram - Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.
- Evolved channel selection by Mat Kelcey - Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.
- Introduction to JAX by Kevin Murphy - Colab that introduces various aspects of the language and applies them to simple ML problems.
- Writing an MCMC sampler in JAX by Jeremie Coullon - Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.
- How to add a progress bar to JAX scans and loops by Jeremie Coullon - Tutorial on how to add a progress bar to compiled loops in JAX using the `host_callback` module.
- Get started with JAX by Aleksa Gordić - A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.

## Community

## Contributing

Contributions welcome! Read the contribution guidelines first.

Contribute to this list: https://github.com/n2cholas/awesome-jax