medium primitives

jit + grad Composition

Why this matters

JAX transformations COMPOSE. jit(grad(f)) is just as legal as jit(f) or grad(f). What the order changes is what gets compiled. jit(grad(f)) traces grad(f) on the first call (and again for each new input shape or dtype) and compiles the resulting forward-and-backward program. Later calls reuse the compiled artifact, which is typically much faster than re-running the autodiff machinery in Python on every call.
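One way to observe the trace-once behavior is to put a Python side effect inside the loss. The side effect runs while JAX traces the function, not when the compiled program executes. A minimal sketch, using the same loss as the worked example and a hypothetical trace_log list that exists only for counting:

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical counter: one entry each time the Python body of loss runs

def loss(w, x):
    trace_log.append("ran")  # Python side effect: not part of any compiled program
    return jnp.sum((w * x) ** 2)

w, x = jnp.array(2.0), jnp.array([1.0, 3.0])

# Without jit, grad re-runs loss's Python code (and the autodiff bookkeeping) on every call.
grad_only = jax.grad(loss)
for _ in range(3):
    grad_only(w, x)
eager_runs = len(trace_log)

# With jit(grad(...)), loss is traced once on the first call; later calls reuse the executable.
trace_log.clear()
jit_grad = jax.jit(jax.grad(loss))
for _ in range(3):
    g = jit_grad(w, x)
jit_runs = len(trace_log)

print(eager_runs, jit_runs, float(g))  # → 3 1 40.0
```

Side effects like this are handy for debugging tracing, but program logic should not depend on them, since the compiled program skips them.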

This composability is a large part of why JAX scales: each transformation (jit, grad, vmap, pmap, …) takes a function and returns a new function, so they stack, and control-flow primitives such as lax.scan work inside all of them. ML training loops typically use jit(grad(loss_fn)) (or jit(value_and_grad(loss_fn))) to get a single compiled forward+backward pass.
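A compiled training step is where this composition usually shows up. The sketch below assumes a toy one-parameter linear model, a hypothetical LEARNING_RATE constant, and plain gradient descent (not any particular optimizer library). Its last lines stack one more transformation, vmap, to get per-example gradients.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for this toy example

def loss(w, x, y):
    # Mean squared error of a one-parameter linear model y ≈ w * x (toy setup).
    return jnp.mean((w * x - y) ** 2)

@jax.jit
def train_step(w, x, y):
    # Forward and backward pass compiled together into one program.
    value, grad_w = jax.value_and_grad(loss)(w, x, y)
    return w - LEARNING_RATE * grad_w, value

x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x             # data generated with true w = 2
w = jnp.array(0.0)      # an array, so every call sees the same input type (no re-trace)
for _ in range(50):
    w, value = train_step(w, x, y)
print(float(w))         # converges to approximately 2.0

# Stacking further: grad for one example, vmap over the batch (w shared), jit the whole thing.
per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
pe = per_example_grad(jnp.array(0.0), x, y)
print(pe)               # one gradient per example: -4, -16, -36
```

Each wrapper only sees a function and returns a function, which is why the same loss can be reused unchanged under value_and_grad, grad, vmap, and jit.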

Worked mini-example

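import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

# Three ways to get d(loss)/dw (grad differentiates w.r.t. the first argument by default):
grad_only = jax.grad(loss)                          # runs op by op, no whole-program compile
jit_grad = jax.jit(jax.grad(loss))                  # gradient compiled as one program
value_grad_jit = jax.jit(jax.value_and_grad(loss))  # returns (loss value, gradient)

# First call traces and compiles; later calls with the same shapes/dtypes reuse the cache:
jit_grad(2.0, jnp.array([1.0, 3.0]))        # → 40.0  (traced and compiled now)
jit_grad(2.0, jnp.array([1.0, 3.0]))        # → 40.0  (reuses the compiled executable)
value_grad_jit(2.0, jnp.array([1.0, 3.0]))  # → (40.0, 40.0)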

Common pitfalls

  • Wrapping the result, not the function: jax.jit(jax.grad(loss_fn(w, b, x))) is wrong. Python evaluates loss_fn(w, b, x) first, so grad receives a scalar array instead of a function and raises a TypeError. Correct: jax.jit(jax.grad(loss_fn, argnums=0)), then call the result with (w, b, x).
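  • Order matters: jit(grad(f)) and grad(jit(f)) are both legal and return the same gradient values, but they compile different things. The first compiles the whole gradient computation (forward and backward) as one program; the second jits the forward pass, and grad then differentiates through it on every call, so the autodiff bookkeeping runs in Python each time. Prefer jit(grad(f)) for compiled training.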
  • Re-tracing on shape changes: like any jitted function, a call with a new input shape (or dtype) triggers a re-trace and re-compile. Keep shapes consistent when possible.
  • Re-defining inside a hot loop: each newly created lambda, closure, or jax.jit(...) wrapper is a new cache entry, so JAX traces it again. Define your jit(grad(f)) once outside the loop instead of recreating it each iteration.
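The four pitfalls above can be reproduced in one short script. This is a sketch: loss_fn's body is hypothetical (any scalar loss works), and the traces list is an assumed debugging device that records x.shape whenever the Python body of loss_fn runs.

```python
import jax
import jax.numpy as jnp

traces = []  # hypothetical log: shape of x each time loss_fn's Python body runs

def loss_fn(w, b, x):
    # Hypothetical loss for illustration only.
    traces.append(x.shape)  # runs when traced (or when called eagerly)
    return jnp.mean((w * x + b) ** 2)

w, b = jnp.array(1.0), jnp.array(2.0)
x1 = jnp.array([3.0])

# Wrapping the result: loss_fn(...) is evaluated first, so grad gets an array, not a function.
try:
    jax.jit(jax.grad(loss_fn(w, b, x1)))
    rejected = False
except TypeError:
    rejected = True  # jax.grad expects a callable

# Order: both compositions are legal and agree numerically.
same = bool(jnp.allclose(jax.jit(jax.grad(loss_fn))(w, b, x1),
                         jax.grad(jax.jit(loss_fn))(w, b, x1)))

# Re-tracing on shape changes: build the compiled gradient once, then vary shapes.
dloss_dw = jax.jit(jax.grad(loss_fn, argnums=0))
traces.clear()
g = dloss_dw(w, b, x1)                # traces for shape (1,)
dloss_dw(w, b, jnp.array([5.0]))      # same shape and dtype: cache hit, no trace
dloss_dw(w, b, jnp.ones(4))           # shape (4,): re-trace and re-compile
shape_traces = list(traces)

# Re-defining inside a hot loop: a fresh wrapper each iteration is traced each iteration.
traces.clear()
for _ in range(3):
    jax.jit(jax.grad(loss_fn))(w, b, jnp.ones(4))
recreated_traces = len(traces)

print(rejected, same, float(g), shape_traces, recreated_traces)
# → True True 30.0 [(1,), (4,)] 3
```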

Problem

Implement jit_grad_quadratic(w, b, x) that computes the gradient of f(w, b) = sum_i (w·x_i + b)² with respect to w only, using jax.jit composed with jax.grad.
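For checking results by hand, the gradient with respect to w follows from the chain rule, with b and the x_i held fixed:

```latex
f(w, b) = \sum_i (w x_i + b)^2,
\qquad
\frac{\partial f}{\partial w} = \sum_i 2\,(w x_i + b)\,x_i
```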

The result is numerically identical to grad_of_quadratic from the previous task. The lesson is the composition pattern: wrap jax.grad with jax.jit so the gradient computation is traced once and compiled.
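The trace-then-compile step can also be made explicit with JAX's ahead-of-time API: .lower(...) traces the function for concrete input shapes, and .compile() produces the executable. A sketch using the worked-example loss (assumed here for illustration, not the task's f):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

w, x = jnp.array(2.0), jnp.array([1.0, 3.0])

# Trace and lower jit(grad(loss)) for these input shapes, then compile it.
lowered = jax.jit(jax.grad(loss)).lower(w, x)
hlo_text = lowered.as_text()   # the traced gradient program, as StableHLO text
compiled = lowered.compile()   # the executable that jit would otherwise cache for reuse

print(float(compiled(w, x)))   # → 40.0
```

Printing hlo_text shows the whole gradient computation as a single module, which is the artifact jit(grad(f)) builds once and reuses.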

Two illustrative examples (not from the test set):

  • w=0.0, b=1.0, x=[1.0, 2.0]: d/dw sum((0·x_i + 1)²) = sum(2(0·x_i + 1)·x_i) = 2·1 + 2·2 = 6.0
  • w=1.0, b=2.0, x=[3.0]: d/dw (1·3 + 2)² = 2(1·3 + 2)·3 = 30.0

Hints

jax jit grad composition
