easy primitives

jax.value_and_grad

Why this matters

jax.value_and_grad(f) returns a function that computes both the value of f and its gradient with one forward pass plus one backward pass. Without it, you would call f(x) to get the value and then jax.grad(f)(x) to get the gradient, which runs the same forward computation a second time internally. value_and_grad reuses the single forward pass and avoids the redundant one. This is the canonical pattern for ML training loops:

loss, grads = jax.value_and_grad(loss_fn)(params, batch)
params = update(params, grads)  # update: your optimizer step, e.g. SGD
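A minimal, runnable sketch of that loop, assuming a toy linear-regression model. The names loss_fn, update and train_step, the dict-of-arrays parameter layout, and the plain-SGD update are all hypothetical stand-ins chosen for illustration; a real project would typically use its own model and optimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical model: linear regression, params stored in a dict (a pytree).
def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Hypothetical optimizer step: plain SGD applied leaf-by-leaf over the pytree.
def update(params, grads, lr=0.1):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

@jax.jit
def train_step(params, batch):
    # One forward + one backward pass yields both the loss and the gradients.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return update(params, grads), loss

x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])  # toy data generated by y = 2 * x
params = {"w": jnp.zeros((1,)), "b": jnp.array(0.0)}

losses = []
for _ in range(200):
    params, loss = train_step(params, (x, y))
    losses.append(float(loss))  # loss measured before this step's update
```

Because grads has the same pytree structure as params, the update can be a single tree_map over both.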

Worked mini-example

import jax, jax.numpy as jnp

def f(w, x):
    return jnp.sum((w * x) ** 2)

val, grad = jax.value_and_grad(f)(2.0, jnp.array([1.0, 3.0]))
# val   = 2^2*1^2 + 2^2*3^2 = 4 + 36 = 40.0
# grad  = d/dw sum((w*x)^2) at w=2 = 2*sum(w*x*x) = 2*2*(1+9) = 40.0
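To confirm that value_and_grad returns exactly what the two separate calls would, the sketch below computes the same example both ways. The separate-call version is shown only for comparison; it is the pattern value_and_grad replaces.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum((w * x) ** 2)

w, x = 2.0, jnp.array([1.0, 3.0])

# One call: a single forward pass whose intermediates feed the backward pass.
val, grad_w = jax.value_and_grad(f)(w, x)

# Two calls: f runs once on its own and once more inside jax.grad.
val_separate = f(w, x)
grad_separate = jax.grad(f)(w, x)
```

Both routes give 40.0 for the value and 40.0 for the gradient, matching the hand calculation in the comments.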

Common pitfalls

  • (value, grad) = jax.grad(f)(args) is wrong: jax.grad returns ONLY the gradient. Use value_and_grad to get both.
  • Order: the result is (value, grad), with the function value first.
  • has_aux=True: lets the inner function return (loss, aux), and value_and_grad then returns ((loss, aux), grad). The aux output is not differentiated, which makes it useful for logging metrics during training without affecting the gradient.
  • argnums: as with jax.grad, argnums=k selects which positional argument to differentiate with respect to (default 0); a tuple such as argnums=(0, 1) returns a tuple of gradients.
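A small sketch combining the last two points: has_aux=True to return a metrics dict alongside the loss, and argnums=(0, 1) to get gradients for two arguments at once. The function loss_with_metrics and its metric are hypothetical, chosen only to illustrate the return structure.

```python
import jax
import jax.numpy as jnp

def loss_with_metrics(w, b, x, y):
    pred = w * x + b
    loss = jnp.mean((pred - y) ** 2)
    # Auxiliary output: returned alongside the loss but never differentiated.
    metrics = {"max_abs_err": jnp.max(jnp.abs(pred - y))}
    return loss, metrics

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 5.0])

# has_aux=True -> ((loss, aux), grads); argnums=(0, 1) -> grads is (d/dw, d/db).
(loss, metrics), (dw, db) = jax.value_and_grad(
    loss_with_metrics, argnums=(0, 1), has_aux=True
)(1.0, 0.0, x, y)
```

Here the residuals are [-2, -3], so loss = 6.5, max_abs_err = 3.0, dw = mean(2·r·x) = -8.0 and db = mean(2·r) = -5.0.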

Problem

Implement loss_and_grad(w, b, x) that computes both the loss f(w, b) = sum_i (w·x_i + b)² and its gradient w.r.t. w only, in a single pass using jax.value_and_grad.

Return a 1-D array of shape (2,): [loss, grad_w].

Two illustrative examples (not from the test set):

  • w=0.0, b=1.0, x=[1.0, 2.0]: loss = sum((0*x+1)^2) = 1 + 1 = 2.0; grad = sum(2*(0*x+1)*x) = 2*1 + 2*2 = 6.0 → [2.0, 6.0]
  • w=1.0, b=2.0, x=[3.0]: loss = (1*3+2)^2 = 25.0; grad = 2*(1*3+2)*3 = 30.0 → [25.0, 30.0]
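For checking an implementation against the two examples above, the loss and its w-gradient can also be written out by hand: f = sum_i (w*x_i + b)^2 and df/dw = sum_i 2*(w*x_i + b)*x_i. The helper reference_loss_and_grad below is a hypothetical closed-form reference that deliberately avoids autodiff; it is not the expected solution, which should use jax.value_and_grad.

```python
import jax.numpy as jnp

def reference_loss_and_grad(w, b, x):
    """Hand-derived reference (no autodiff) for f = sum_i (w*x_i + b)^2."""
    x = jnp.asarray(x, dtype=jnp.float32)
    r = w * x + b                      # per-sample term (w*x_i + b)
    loss = jnp.sum(r ** 2)
    grad_w = jnp.sum(2.0 * r * x)      # df/dw = sum_i 2*(w*x_i + b)*x_i
    return jnp.array([loss, grad_w])   # shape (2,): [loss, grad_w]

ex1 = reference_loss_and_grad(0.0, 1.0, [1.0, 2.0])  # expected values: [2.0, 6.0]
ex2 = reference_loss_and_grad(1.0, 2.0, [3.0])       # expected values: [25.0, 30.0]
```

Comparing an autodiff-based answer with this closed form (e.g. via jnp.allclose) is a quick sanity check on a few hand-picked inputs.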

Hints

jax value-and-grad autodiff
