
Gradient over Pytree Params

Why this matters

jax.grad works on pytree-valued inputs: the argument being differentiated can be any pytree of arrays (nested dicts, lists, tuples), not just a single array. If your params are a nested dict like {"layer1": {"w": ..., "b": ...}, "layer2": {...}}, calling jax.grad(loss_fn)(params) returns gradients with the same nested-dict structure. This is why libraries built on JAX (Optax, Flax, Equinox) can hand the grads from jax.grad straight to an optimizer update without any flatten/unflatten boilerplate: params and grads line up structurally.
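A minimal sketch of the nested case, assuming a hypothetical two-layer model (the layer names, shapes, and loss_fn below are illustrative, not taken from any library). It checks that the gradient dict has the same keys and leaf shapes as params:

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer params, nested like the dict described above.
params = {
    "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"w": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}

def loss_fn(params, x):
    # Tiny MLP: tanh hidden layer, linear output, summed squares (a scalar).
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    y_hat = h @ params["layer2"]["w"] + params["layer2"]["b"]
    return jnp.sum(y_hat ** 2)

x = jnp.ones((2, 3))
grads = jax.grad(loss_fn)(params, x)

# grads mirrors params: same nested keys and the same shape at every leaf.
same_structure = jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
same_shapes = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda g, p: g.shape == p.shape, grads, params)
)
print(same_structure, same_shapes)  # True True
```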

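In real JAX code, a typical training step looks like this (loss_fn, params, batch, an Optax optimizer such as optax.adam(1e-3), and opt_state = optimizer.init(params) are assumed to be defined already):

import jax
import optax

loss, grads = jax.value_and_grad(loss_fn)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)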

Notice that grads is a pytree with the same structure as params. Optax and Flax operate on whole pytrees of arrays rather than on one flat tensor.
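The same structural pairing is visible without Optax. The sketch below swaps in plain SGD for the Optax optimizer; sgd_step, the linear loss_fn, the learning rate, and the toy data are all hypothetical. Because grads mirrors params, jax.tree_util.tree_map can walk both trees leaf by leaf:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    x, y = batch
    y_hat = x @ params["w"] + params["b"]
    return jnp.mean((y_hat - y) ** 2)  # scalar mean-squared error

def sgd_step(params, batch, lr=0.1):
    # Hypothetical helper: plain SGD standing in for optimizer.update + optax.apply_updates.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # grads has the same structure as params, so tree_map pairs each leaf with its gradient.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])

losses = []
for _ in range(50):
    params, loss = sgd_step(params, (x, y))
    losses.append(float(loss))  # loss is measured before each update
print(losses[0], losses[-1])
```

No manual flattening appears anywhere: the update rule is written once per leaf and tree_map applies it across the whole dict.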

This problem flattens params into a single array for testability, but the autodiff machinery is the same: a bare array is a pytree with exactly one leaf, so jax.grad treats flat_params as a single-leaf pytree and returns the gradient as a single array of the same shape. The same call works unchanged for nested params.
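To see that a bare array really is a single-leaf pytree, jax.tree_util can inspect it directly. A small sketch (toy_loss is illustrative only):

```python
import jax
import jax.numpy as jnp

flat_params = jnp.array([1.0, 0.0, 0.0])

# A bare array is a pytree whose only leaf is the array itself.
leaves = jax.tree_util.tree_leaves(flat_params)
treedef = jax.tree_util.tree_structure(flat_params)
print(len(leaves), treedef.num_leaves)  # 1 1

def toy_loss(p):
    return jnp.sum(p ** 2)  # illustrative scalar loss

g = jax.grad(toy_loss)(flat_params)
print(g.shape, g.tolist())  # (3,) [2.0, 0.0, 0.0]
```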

Worked mini-example

import jax, jax.numpy as jnp

params = {"w": jnp.array([1., 2.]), "b": jnp.array([0.5])}
x      = jnp.array([[1.0, -1.0]])

def loss(p, x):
    y_hat = x @ p["w"] + p["b"]
    return jnp.sum(y_hat ** 2)

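grads = jax.grad(loss)(params, x)
# grads = {"w": [-1., 1.], "b": [-1.]}, the same nested-dict structure as `params`.
# Here y_hat = 1*1 + (-1)*2 + 0.5 = -0.5, so dL/dy_hat = 2 * y_hat = -1,
# giving dL/dw = -1 * [1, -1] = [-1, 1] and dL/db = -1.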
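The dict example and the flat-vector layout of the Problem section are related by jax.flatten_util.ravel_pytree, which concatenates a pytree's leaves into one 1-D array and returns a function that rebuilds the pytree. A sketch reusing the params, x, and loss above (redefined so the block runs on its own). Note that dict leaves are flattened in sorted-key order, so the flat vector here is [b, w0, w1], not the weights-then-bias layout the Problem section uses:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.array([1., 2.]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, -1.0]])

def loss(p, x):
    y_hat = x @ p["w"] + p["b"]
    return jnp.sum(y_hat ** 2)

# Concatenate all leaves into one 1-D vector; `unravel` rebuilds the dict.
# Dict leaves are flattened in sorted-key order, so flat is [b, w0, w1].
flat, unravel = ravel_pytree(params)

def flat_loss(flat_p, x):
    return loss(unravel(flat_p), x)

flat_grads = jax.grad(flat_loss)(flat, x)   # one array, shape (3,)
tree_grads = jax.grad(loss)(params, x)      # dict with keys "w" and "b"

print(flat.tolist())                         # [0.5, 1.0, 2.0]
print(flat_grads.tolist())                   # [-1.0, -1.0, 1.0]
print(ravel_pytree(tree_grads)[0].tolist())  # [-1.0, -1.0, 1.0]
```

Differentiating the flat vector and differentiating the dict give the same numbers; only the packaging differs.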

Common pitfalls

  • jax.grad's argnums selects which positional argument (each may be a pytree) to differentiate with respect to; it defaults to 0.
  • The loss must be a scalar, as in the earlier grad problems: reduce with jnp.sum, jnp.mean, etc.
  • Slicing inside loss_fn (p[:d], p[d:d+1]) is fine: indexing a traced array with static Python ints is recorded as a differentiable JAX operation, so gradients flow back to the matching entries of flat_params.
  • Taking d from x.shape[1] is safe: shapes are static, so d is a plain Python int at trace time.
  • Don't differentiate x: the test only asks for the gradient w.r.t. flat_params (argnums=0). Passing argnums=(0, 1) returns a tuple of gradients instead of a single array.
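The argnums and scalar-output pitfalls can be checked directly. A short sketch with an illustrative two-argument loss (not the problem's model):

```python
import jax
import jax.numpy as jnp

def loss(p, x):
    return jnp.sum((x @ p) ** 2)  # jnp.sum reduces the output to a scalar

p = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, -1.0]])

g = jax.grad(loss)(p, x)                          # default argnums=0: one array shaped like p
g_p, g_x = jax.grad(loss, argnums=(0, 1))(p, x)   # a tuple: one gradient per selected argument
print(g.tolist(), g_x.shape)                      # [-2.0, 2.0] (1, 2)

# Dropping the reduction leaves a non-scalar output, which jax.grad rejects.
raised = False
try:
    jax.grad(lambda p, x: (x @ p) ** 2)(p, x)
except TypeError:
    raised = True
print(raised)  # True
```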

Problem

Implement grad_over_flat_params(flat_params, x).

flat_params is a 1-D array of shape (d + 1,) — the first d entries are weights and the last entry is a bias. x is a 2-D array of shape (N, d). The model is y_hat = x @ weights + bias, and the loss is sum(y_hat ** 2) (no targets — purely for testing the grad mechanism).

Return the gradient of the loss w.r.t. flat_params — a 1-D array of the same shape (d + 1,).
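For checking answers by hand, the chain rule gives a closed form. Writing X for x, w for the first d entries of flat_params, and b for the last entry:

```latex
\hat{y} = X w + b\,\mathbf{1}, \qquad L = \sum_{n=1}^{N} \hat{y}_n^{2}

\frac{\partial L}{\partial w_j} = \sum_{n=1}^{N} 2\,\hat{y}_n X_{nj}
\quad\Longrightarrow\quad \nabla_w L = 2\,X^{\top}\hat{y}

\frac{\partial L}{\partial b} = \sum_{n=1}^{N} 2\,\hat{y}_n
```

Stacking the d weight gradients and the single bias gradient gives the (d + 1,) result.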

Two illustrative examples (not from the test set):

  • flat_params = [1.0, 0.0, 0.0], x = [[2.0, 3.0]]: d=2, w=[1,0], b=[0]. y_hat = 1*2 + 0*3 + 0 = 2. loss = 4. dL/dw0 = 2*2*2 = 8, dL/dw1 = 2*2*3 = 12, dL/db = 2*2 = 4. → grad = [8.0, 12.0, 4.0]

  • flat_params = [0.0, 0.0, 1.0], x = [[1.0, 0.0], [0.0, 1.0]]: d=2, w=[0,0], b=[1]. y_hat = [0+1, 0+1] = [1, 1]. loss = 2. dL/dw0 = 2*1*1 + 2*1*0 = 2, dL/dw1 = 2*1*0 + 2*1*1 = 2, dL/db = 2*1 + 2*1 = 4. → grad = [2.0, 2.0, 4.0]
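One possible sketch, assuming the layout described in the Problem section (weights first, bias last). It wraps the model in a scalar loss of the flat vector and calls jax.grad, and it reproduces both illustrative examples:

```python
import jax
import jax.numpy as jnp

def grad_over_flat_params(flat_params, x):
    d = x.shape[1]  # static Python int, known at trace time

    def loss(p):
        w, b = p[:d], p[d]          # first d entries are weights, last entry is the bias
        y_hat = x @ w + b           # shape (N,)
        return jnp.sum(y_hat ** 2)  # scalar, as jax.grad requires

    # flat_params is a single-leaf pytree, so the gradient is one array of shape (d + 1,).
    return jax.grad(loss)(flat_params)

ex1 = grad_over_flat_params(jnp.array([1.0, 0.0, 0.0]), jnp.array([[2.0, 3.0]]))
ex2 = grad_over_flat_params(jnp.array([0.0, 0.0, 1.0]), jnp.array([[1.0, 0.0], [0.0, 1.0]]))
print(ex1.tolist())  # [8.0, 12.0, 4.0]
print(ex2.tolist())  # [2.0, 2.0, 4.0]
```

Using p[d] (a scalar) rather than p[d:d+1] (shape (1,)) gives the same gradient here, because the bias broadcasts over the N rows either way.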

Hints

jax grad pytree
