Learn JAX
A complete path from JAX fundamentals to production model code. Each track is self-contained; taken end to end, they cover the language, the math, the libraries, and the patterns you'll use to ship.
-
JAX Autodiff
Forward and reverse mode, custom derivatives, gradient checkpointing, per-example gradients. The autodiff toolbox.
-
JAX Foundations
Functional ML: pure functions, vmap, pytrees. The JAX way of thinking.
-
JAX Vectorization & Control Flow
vmap, scan, while_loop, fori_loop, cond, and switch: JAX's primitives for batching and control flow.
-
JAX Stochasticity
Distributions, reparameterization, sampling techniques, MCMC, gradient estimators. Randomness as a controllable resource.
-
JAX Performance & Distributed
jit details, memory and device control, sharding APIs (Mesh, PartitionSpec, shard_map), mixed precision. The performance and parallelism toolbox.
-
JAX Conceptual Deep-Dives
Custom pytree nodes, host interop callbacks, quantization patterns, jaxpr inspection, advanced debugging. Beyond the curriculum.
-
JAX Numerical Computing
Linear algebra, decompositions, FFT, convolutions, sparse arrays, polynomial fitting, numerical stability. Scientific computing in JAX.
-
JAX Profiling & Debugging
Profiler hooks, debug primitives, fixes for common tracer and concretization errors, AOT compilation. Hands-on problems for production work.
-
Optax
Gradient transforms, optimizer chains, schedules, weight decay, EMA, masking. The production optimizer library for JAX.
-
Flax
Modules, layers from scratch, attention, transformer architectures, training loops, lifted transforms. Production model code in JAX (Linen API).
-
Flax NNX
Flax's modern object-oriented API. Stateful modules, eager-mode workflow, split/merge state, lifted transforms, training, sharding. Standalone path through the recommended API for new projects.
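For a small taste of how the core transformations named in the Foundations, Autodiff, and Vectorization tracks compose, here is a minimal sketch that computes per-example gradients with `jax.grad`, `jax.vmap`, and `jax.jit`. The squared-error loss, the weights, and the tiny batch are hypothetical illustrations chosen for this sketch, not material taken from any track.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss: squared error of a linear model on ONE example.
# It is a pure function of its inputs, which is what JAX transforms expect.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Reverse-mode gradient with respect to the first argument (w).
grad_loss = jax.grad(loss)

# Per-example gradients: vmap maps grad_loss over axis 0 of x and y,
# while in_axes=None broadcasts the same w to every example.
# jit then compiles the composed function with XLA.
per_example_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

# Hypothetical weights and a batch of three examples.
w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = per_example_grads(w, xs, ys)
print(g.shape)  # one gradient row per example: (3, 2)
```

Averaging the rows of `g` recovers the ordinary gradient of the mean loss; keeping them separate is what makes techniques such as per-example clipping possible.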