Gradient over Pytree Params
Why this matters
jax.grad works on pytree-valued inputs. If your params are a nested
dict such as {"layer1": {"w": ..., "b": ...}, "layer2": {...}}, then
jax.grad(loss_fn)(params) returns gradients with the same nested-dict
structure. This is why libraries built on JAX (Optax, Flax, Equinox) can
combine params and grads leaf by leaf without any flatten/unflatten
boilerplate: params and grads line up structurally.
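A minimal sketch of that structural match, assuming a toy two-layer model. The layer names, shapes, and mlp_loss are illustrative choices, not part of any library:

```python
import jax
import jax.numpy as jnp

# Hypothetical nested params: two dense layers, each a dict with "w" and "b".
params = {
    "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))},
    "layer2": {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))},
}

def mlp_loss(p, x):
    # Forward pass through both layers, reduced to a scalar loss.
    h = jnp.tanh(x @ p["layer1"]["w"] + p["layer1"]["b"])
    y_hat = h @ p["layer2"]["w"] + p["layer2"]["b"]
    return jnp.mean(y_hat ** 2)

x = jnp.ones((5, 3))
grads = jax.grad(mlp_loss)(params, x)

# grads mirrors params: same keys, same nesting, same leaf shapes.
same_tree = jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
same_shapes = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda g, p: g.shape == p.shape, grads, params)
)
print(same_tree, same_shapes)  # True True
```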
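In real JAX code, a typical training step with Optax looks like this:
```python
# Assumes loss_fn, batch, an Optax optimizer, and opt_state = optimizer.init(params).
loss, grads = jax.value_and_grad(loss_fn)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```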
Notice that grads is a pytree with the same structure as params. Optax and
Flax never see individual “tensors”; they see pytrees of tensors.
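For plain SGD, the optimizer.update / optax.apply_updates pair amounts to a leaf-by-leaf p - lr * g. The sketch below does that by hand with jax.tree_util.tree_map instead of Optax; sgd_step, the learning rate, and the toy data are illustrative assumptions, and real Optax optimizers also carry state such as momentum:

```python
import jax
import jax.numpy as jnp

def loss_fn(p, x, y):
    # Linear model with a mean-squared-error loss (scalar output).
    y_hat = x @ p["w"] + p["b"]
    return jnp.mean((y_hat - y) ** 2)

def sgd_step(params, x, y, lr=0.01):
    # Scalar loss plus a grads pytree shaped exactly like params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Leaf-by-leaf update: no flattening, the dict structure is preserved.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])

losses = []
for _ in range(3):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
print(losses)  # decreases step by step for this small learning rate
```

The learning rate is kept small on purpose: this quadratic loss has a largest curvature of roughly 31, so a step much above 2/31 would diverge.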
This problem flattens params into a single array for testability. To the
autodiff machinery, flat_params is simply a pytree with one leaf (one array),
and the gradient comes back with the same structure: a single array of the
same shape. Nothing changes conceptually for nested params; the generalization is automatic.
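To see that the flat and nested views agree, here is a sketch using jax.flatten_util.ravel_pytree, which concatenates a pytree's leaves into one 1-D array and returns a function that rebuilds the pytree. The model mirrors this problem's; dict_loss and the data are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def dict_loss(p, x):
    y_hat = x @ p["w"] + p["b"]
    return jnp.sum(y_hat ** 2)

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, -1.0], [2.0, 0.5]])

# Flatten the dict into one 1-D array plus a function that rebuilds the dict.
flat, unravel = ravel_pytree(params)

# Gradient w.r.t. the nested dict, flattened afterwards for comparison.
g_tree_flat, _ = ravel_pytree(jax.grad(dict_loss)(params, x))

# Gradient w.r.t. the flat array: single-leaf pytree in, same-shape array out.
g_flat = jax.grad(lambda f: dict_loss(unravel(f), x))(flat)

print(g_flat.shape == flat.shape, bool(jnp.allclose(g_flat, g_tree_flat)))
```

One caveat: ravel_pytree follows JAX's flattening order, which sorts dict keys, so here "b" lands before "w". That differs from this problem's weights-first layout, which is why the comparison goes through unravel rather than assuming a particular ordering.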
Worked mini-example
```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1., 2.]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, -1.0]])

def loss(p, x):
    y_hat = x @ p["w"] + p["b"]
    return jnp.sum(y_hat ** 2)

grads = jax.grad(loss)(params, x)
# grads = {"w": <gradient w.r.t. w>, "b": <gradient w.r.t. b>}
# Same nested-dict shape as `params`.
```
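Running the same mini-example through jax.value_and_grad gives concrete numbers. By hand: y_hat = 1*1 + (-1)*2 + 0.5 = -0.5, so the loss is 0.25, dL/dw = 2 * y_hat * x[0] = [-1, 1], and dL/db = 2 * y_hat = -1:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, -1.0]])

def loss(p, x):
    y_hat = x @ p["w"] + p["b"]
    return jnp.sum(y_hat ** 2)

# Returns the scalar loss and a dict of grads shaped like params.
value, grads = jax.value_and_grad(loss)(params, x)
print(value)       # 0.25
print(grads["w"])  # [-1.  1.]
print(grads["b"])  # [-1.]
```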
Common pitfalls
- jax.grad's argnums selects which positional argument (each of which may be a whole pytree) to differentiate with respect to; it defaults to 0.
- The loss must be a scalar: as in the earlier grad problems, reduce with jnp.sum, jnp.mean, etc.
- Slicing inside loss_fn, e.g. loss_fn(p[:d], p[d:d+1]), is fine: with static integer bounds, indexing a traced array is recorded as a JAX operation, so the gradient flows back through the slices to the matching entries of p.
- Getting d from x.shape[1] is fine: shapes are static, so d is a plain Python int at trace time.
- Don't differentiate w.r.t. x: the test only asks for the gradient w.r.t. params (argnums=0). Passing argnums=(0, 1) returns a tuple of gradients instead of a single array.
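The argnums and scalar-loss pitfalls can be seen directly in a small sketch; the toy loss, p, and x here are illustrative:

```python
import jax
import jax.numpy as jnp

def loss(p, x):
    return jnp.sum((x @ p) ** 2)

p = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])

g_p = jax.grad(loss)(p, x)                        # argnums=0: one array, shaped like p
g_p2, g_x = jax.grad(loss, argnums=(0, 1))(p, x)  # tuple: (grad w.r.t. p, grad w.r.t. x)

# Non-scalar output: jax.grad raises a TypeError instead of returning a Jacobian.
try:
    jax.grad(lambda p, x: (x @ p) ** 2)(p, x)
    scalar_error = False
except TypeError:
    scalar_error = True

print(g_p.shape, g_x.shape, scalar_error)  # (2,) (2, 2) True
```

If a full Jacobian of a vector-valued function is what you actually want, jax.jacrev or jax.jacfwd are the tools for that, not jax.grad.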
Problem
Implement grad_over_flat_params(flat_params, x).
flat_params is a 1-D array of shape (d + 1,) — the first d entries
are weights and the last entry is a bias. x is a 2-D array of shape
(N, d). The model is y_hat = x @ weights + bias, and the loss is
sum(y_hat ** 2) (no targets — purely for testing the grad mechanism).
Return the gradient of the loss w.r.t. flat_params — a 1-D array of
the same shape (d + 1,).
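Working the derivative out by hand gives a closed form that jax.grad should reproduce. Writing X for x, w for the first d entries of flat_params, and b for the last entry:

```latex
\hat{y} = X w + b\,\mathbf{1}, \qquad L = \sum_{i=1}^{N} \hat{y}_i^{2}

\frac{\partial L}{\partial w} = 2\,X^{\top}\hat{y} \in \mathbb{R}^{d}, \qquad
\frac{\partial L}{\partial b} = 2\sum_{i=1}^{N} \hat{y}_i
```

The gradient w.r.t. flat_params is these two pieces concatenated, weights first and bias last.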
Two illustrative examples (not from the test set):
- flat_params = [1.0, 0.0, 0.0], x = [[2.0, 3.0]]: d = 2, w = [1, 0], b = [0]. y_hat = 1*2 + 0*3 + 0 = 2, loss = 4. dL/dw0 = 2*2*2 = 8, dL/dw1 = 2*2*3 = 12, dL/db = 2*2 = 4. → grad = [8.0, 12.0, 4.0]
- flat_params = [0.0, 0.0, 1.0], x = [[1.0, 0.0], [0.0, 1.0]]: d = 2, w = [0, 0], b = [1]. y_hat = [0 + 1, 0 + 1] = [1, 1], loss = 2. dL/dw0 = 2*1*1 + 2*1*0 = 2, dL/dw1 = 2*1*0 + 2*1*1 = 2, dL/db = 2*1 + 2*1 = 4. → grad = [2.0, 2.0, 4.0]
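One possible sketch of grad_over_flat_params, following the slicing approach from the pitfalls above; the helper names _linear_sq_loss and closed_form_grad are chosen here for illustration. It reproduces both illustrative examples and can be cross-checked against the hand-derived formula:

```python
import jax
import jax.numpy as jnp

def _linear_sq_loss(w, b, x):
    # b has shape (1,) and broadcasts across the N rows of x @ w.
    y_hat = x @ w + b
    return jnp.sum(y_hat ** 2)

def grad_over_flat_params(flat_params, x):
    flat_params = jnp.asarray(flat_params)
    x = jnp.asarray(x)
    d = x.shape[1]  # static Python int at trace time

    def loss_fn(p):
        # Slice the single flat array into weights and bias inside the traced function.
        return _linear_sq_loss(p[:d], p[d:d + 1], x)

    # One array in, one array of shape (d + 1,) out.
    return jax.grad(loss_fn)(flat_params)

def closed_form_grad(flat_params, x):
    # Hand-derived: dL/dw = 2 X^T y_hat, dL/db = 2 * sum(y_hat).
    d = x.shape[1]
    y_hat = x @ flat_params[:d] + flat_params[d]
    return jnp.append(2 * x.T @ y_hat, 2 * jnp.sum(y_hat))

print(grad_over_flat_params(jnp.array([1.0, 0.0, 0.0]), jnp.array([[2.0, 3.0]])))
# [ 8. 12.  4.]
print(grad_over_flat_params(jnp.array([0.0, 0.0, 1.0]),
                            jnp.array([[1.0, 0.0], [0.0, 1.0]])))
# [2. 2. 4.]
```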