jit + grad Composition
Why this matters
JAX transformations COMPOSE. jit(grad(f)) is just as legal as jit(f) or
grad(f). The order determines what gets compiled: jit(grad(f)) traces grad(f)
once (per distinct combination of input shapes and dtypes) and compiles the
resulting program. Subsequent calls with matching inputs reuse the compiled
artifact, which is typically much faster than re-running the autodiff machinery
in Python on every call.
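One way to observe the trace-once behavior is a Python side effect inside the function being differentiated: plain Python code runs only while JAX traces, not when a cached compiled program executes. The sketch below uses a hypothetical global counter, trace_count, purely for illustration; it demonstrates the caching behavior rather than specifying JAX's exact cache rules.

```python
import jax
import jax.numpy as jnp

# Illustrative global counter (hypothetical, not a JAX feature): incremented
# by a Python side effect that only runs while JAX traces `f`.
trace_count = 0

def f(w, x):
    global trace_count
    trace_count += 1
    return jnp.sum((w * x) ** 2)

jit_grad_f = jax.jit(jax.grad(f))

x = jnp.array([1.0, 3.0])
jit_grad_f(2.0, x)  # first call: traces grad(f), then compiles
jit_grad_f(2.0, x)  # same shapes and dtypes: reuses the compiled program
print(trace_count)  # → 1

jit_grad_f(2.0, jnp.array([1.0, 3.0, 5.0]))  # new input shape: traced again
print(trace_count)  # → 2
```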
This composability is a central reason JAX scales: each transformation
(jit, grad, vmap, pmap, …) takes a function and returns a new function,
so they stack freely. ML training steps are typically written as jit(grad(loss_fn))
(or jit(value_and_grad(...))) to get a single compiled forward+backward pass.
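A minimal sketch of that training pattern, assuming a toy one-parameter regression loss and plain gradient descent. The names loss_fn, loss_and_grad, train_step, and the learning rate 0.1 are illustrative choices, not taken from any library. Each step calls one compiled program that returns both the loss and its gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: fit y = w * x with a single scalar weight w.
def loss_fn(w, x, y):
    return jnp.mean((w * x - y) ** 2)

# One compiled program returns both the loss and d(loss)/dw.
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

def train_step(w, x, y, lr=0.1):
    loss, g = loss_and_grad(w, x, y)
    return w - lr * g, loss  # plain gradient-descent update

x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x          # data generated with a "true" weight of 2.0
w = jnp.array(0.0)

losses = []
for _ in range(50):
    w, loss = train_step(w, x, y)
    losses.append(float(loss))

print(float(w))  # converges toward 2.0
```

In real code the whole train_step is often wrapped in jax.jit as well, so the parameter update is compiled into the same program as the forward and backward passes.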
Worked mini-example
import jax, jax.numpy as jnp
def loss(w, x):
    return jnp.sum((w * x) ** 2)
# Three ways to get the gradient (eager, compiled, compiled together with the loss value):
grad_only = jax.grad(loss)
jit_grad = jax.jit(jax.grad(loss))
value_grad_jit = jax.jit(jax.value_and_grad(loss))
# First call traces and compiles:
jit_grad(2.0, jnp.array([1.0, 3.0])) # → 40.0 (compiled now)
jit_grad(2.0, jnp.array([1.0, 3.0])) # → 40.0 (uses cache)
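The snippet above defines grad_only and value_grad_jit but never calls them. The sketch below re-declares loss so it runs on its own, checks that all three callables agree on the gradient, and shows that value_and_grad also returns the loss value. It then stacks a third transformation, jax.vmap, to get one gradient per row of a batch; the batch X is made up for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

grad_only = jax.grad(loss)
jit_grad = jax.jit(jax.grad(loss))
value_grad_jit = jax.jit(jax.value_and_grad(loss))

w, x = 2.0, jnp.array([1.0, 3.0])

g_eager = grad_only(w, x)            # gradient 40.0, executed op by op
g_jit = jit_grad(w, x)               # gradient 40.0, compiled on first call
value, g_vg = value_grad_jit(w, x)   # loss value 40.0 and gradient 40.0

# Stacking vmap: map over the rows of X while sharing the same w.
per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
X = jnp.array([[1.0, 3.0],
               [2.0, 0.0]])
print(per_example_grad(w, X))        # → [40. 16.]
```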
Common pitfalls
- Wrapping the result, not the function: jax.jit(jax.grad(loss_fn(w, b, x))) is wrong. It passes the scalar returned by loss_fn(w, b, x) to grad instead of the function itself, and that value is not callable. Correct: jax.jit(jax.grad(loss_fn, argnums=0)), then call the result with (w, b, x).
- Order matters: jit(grad(f)) and grad(jit(f)) are both legal and return the same values, but they compile different things. The first compiles the entire gradient computation as one program; the second compiles f and differentiates through that compiled call, with the outer grad transformation running in Python on every call. Prefer jit(grad(f)) for compiled training.
- Re-tracing on shape changes: like any jitted function, calls with different input shapes (or dtypes) trigger re-traces and recompilation. Keep shapes consistent when possible.
- Re-defining inside a hot loop: each new lambda or closure is a new function object, and therefore a new trace target. Define your jit(grad(f)) once outside the loop; don't recreate it each iteration.
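A small sketch of three of these pitfalls, assuming a toy loss_fn(w, b, x) of the same form as the problem below (the name and values are illustrative). It shows that transforming a value instead of a function fails, that both composition orders give the same number, and that the composed function is built once and reused.

```python
import jax
import jax.numpy as jnp

# Toy loss with the same form as the problem below (illustrative only).
def loss_fn(w, b, x):
    return jnp.sum((w * x + b) ** 2)

w, b, x = 1.0, 2.0, jnp.array([3.0])

# Pitfall 1: loss_fn(w, b, x) is already a value (a scalar array), not a
# function, so JAX refuses to transform it.
raised = False
try:
    jax.jit(jax.grad(loss_fn(w, b, x)))(w, b, x)
except TypeError:
    raised = True
print(raised)  # → True

# Correct: pass the function itself and select w with argnums=0.
jit_grad = jax.jit(jax.grad(loss_fn, argnums=0))
print(jit_grad(w, b, x))  # → 30.0

# Pitfall 2: the other order is legal and gives the same number here;
# the difference is what gets compiled, not the result.
grad_of_jit = jax.grad(jax.jit(loss_fn), argnums=0)
print(grad_of_jit(w, b, x))  # → 30.0

# Pitfall 4: build the composed function once (above) and reuse it.
# Writing jax.jit(jax.grad(loss_fn))(w, b, x) inside the loop would create a
# new function object, and thus a new trace target, on every iteration.
for _ in range(3):
    g = jit_grad(w, b, x)
```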
Problem
Implement jit_grad_quadratic(w, b, x) that computes the gradient of
f(w, b) = sum_i (w·x_i + b)² with respect to w only, using
jax.jit composed with jax.grad.
The result is numerically identical to grad_of_quadratic from the
previous task. The lesson is the composition pattern: wrap jax.grad
with jax.jit so the gradient computation is traced once and compiled.
Two illustrative examples (not from the test set):
- w=0.0, b=1.0, x=[1.0, 2.0]: d/dw sum_i (0·x_i + 1)² = sum_i 2(0·x_i + 1)·x_i = 2·1·1 + 2·1·2 = 6.0
- w=1.0, b=2.0, x=[3.0]: d/dw (1·3 + 2)² = 2(1·3 + 2)·3 = 30.0
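Analytically, ∂f/∂w = sum_i 2(w·x_i + b)·x_i, which is what the two examples evaluate. A minimal sketch of one possible implementation follows; it is not the official solution. The helper names quadratic and _jit_grad_w, and the choice to convert every input to a float32 array, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def quadratic(w, b, x):
    # f(w, b) = sum_i (w * x_i + b)^2
    return jnp.sum((w * x + b) ** 2)

# Compose once at module level: gradient with respect to argument 0 (w),
# then jit, so repeated calls with same-shaped inputs reuse one compiled program.
_jit_grad_w = jax.jit(jax.grad(quadratic, argnums=0))

def jit_grad_quadratic(w, b, x):
    # Convert inputs to float32 arrays: grad needs a floating-point w, and a
    # plain Python list would enter jit as a pytree of scalars, not one array.
    w = jnp.asarray(w, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)
    return _jit_grad_w(w, b, x)

print(jit_grad_quadratic(0.0, 1.0, [1.0, 2.0]))  # → 6.0
print(jit_grad_quadratic(1.0, 2.0, [3.0]))       # → 30.0
```

Building the jitted gradient once at module level, rather than inside jit_grad_quadratic, follows the hot-loop advice above: every call reuses the same jitted function and its cache.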