Microjax: JAX in Two Classes and Six Functions

2025-07-07

Inspired by Andrej Karpathy's Micrograd, Microjax is a library that replicates core JAX functionality using only two classes and six functions. Where Micrograd mirrors PyTorch's object-oriented API, Microjax adopts JAX's more functional programming style. This tutorial borrows heavily from Matthew J. Johnson's excellent 2017 presentation on autograd, the predecessor to JAX, simplifying that material and packaging it as a notebook.
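
To make the "functional style" contrast concrete: in JAX, `grad` is a higher-order function that takes a Python function and returns a new function computing its derivative. In PyTorch, by contrast, you call `.backward()` on a result and read gradients stored on tensors. The toy sketch below illustrates only that interface. It assumes forward-mode dual numbers purely for brevity, and the `Dual` class and `grad` helper are hypothetical. They are not Microjax's actual two classes and six functions.

```python
# Toy illustration of the functional style: grad(f) returns a new function.
# Assumption: forward-mode dual numbers are used for brevity; this is NOT
# Microjax's implementation.
from typing import Callable, Union

Number = Union[int, float]


class Dual:
    """A value paired with its derivative (tangent) w.r.t. the input."""

    def __init__(self, value: float, tangent: float = 0.0) -> None:
        self.value = value
        self.tangent = tangent

    @staticmethod
    def lift(x: Union["Dual", Number]) -> "Dual":
        # Plain numbers are constants, so their tangent is zero.
        return x if isinstance(x, Dual) else Dual(x, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

    __rmul__ = __mul__


def grad(f: Callable[[Dual], Dual]) -> Callable[[float], float]:
    """Hypothetical helper: return a new function that computes df/dx."""
    def df(x: float) -> float:
        # Seed the input's tangent with 1, then read the output's tangent.
        return Dual.lift(f(Dual(x, 1.0))).tangent
    return df


def f(x):
    return 3 * x * x + 2 * x + 1  # f'(x) = 6x + 2


df = grad(f)    # no mutable state, no .backward(): just a new function
print(df(2.0))  # 14.0
```

Because `grad` only consumes and returns plain functions, the derivative can be passed around, stored, or called at many points like any other Python function.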

Development