jax.value_and_grad
Why this matters
jax.value_and_grad(f) returns a function that computes both the value
and the gradient of f in a single forward+backward pass. Without it,
you’d evaluate f(x) once for the value and then again (implicitly) inside
jax.grad(f)(x). value_and_grad reuses the same forward pass, saving one
evaluation of f. This is the canonical pattern for ML training loops:
loss, grads = jax.value_and_grad(loss_fn)(params, batch)
params = update(params, grads)
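A minimal, runnable version of that training-loop pattern is sketched below. It is an illustration under assumptions, not a prescribed implementation: the body of loss_fn (mean squared error of a linear model), the hypothetical helpers sgd_update and train_step, the learning rate, and the toy batch are all made up here, and plain SGD stands in for the unspecified update.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss: mean squared error of a linear model.
# params is a dict pytree with a weight vector "w" and a scalar bias "b".
def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Hypothetical stand-in for `update`: plain SGD applied leaf by leaf.
def sgd_update(params, grads, lr=0.01):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

@jax.jit
def train_step(params, batch):
    # One forward + backward pass yields both the loss and the gradients
    # (grads has the same dict structure as params).
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return sgd_update(params, grads), loss

# Toy data, invented for illustration.
params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
batch = (jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([5.0, 11.0]))

losses = []
for _ in range(3):
    params, loss = train_step(params, batch)
    losses.append(float(loss))  # loss measured before each update
```

Because the loss comes back alongside the gradients, it can be logged at no extra cost; each step here should report a lower loss than the previous one.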
Worked mini-example
import jax, jax.numpy as jnp
def f(w, x):
    return jnp.sum((w * x) ** 2)
val, grad = jax.value_and_grad(f)(2.0, jnp.array([1.0, 3.0]))
# val = 2^2*1^2 + 2^2*3^2 = 4 + 36 = 40.0
# grad = d/dw sum((w*x)^2) at w=2 = 2*sum(w*x*x) = 2*2*(1+9) = 40.0
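To check the hand-computed numbers, the sketch below compares the fused call against calling f and jax.grad(f) separately, and against the analytic gradient 2·w·sum(x_i²). The variable names val_sep, grad_sep, and grad_analytic are introduced here purely for illustration.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum((w * x) ** 2)

x = jnp.array([1.0, 3.0])

# Fused: one call returns (value, gradient w.r.t. the first argument, w).
val, grad = jax.value_and_grad(f)(2.0, x)

# Separate: f runs once here and its forward pass runs again inside jax.grad.
val_sep = f(2.0, x)
grad_sep = jax.grad(f)(2.0, x)

# Analytic: d/dw sum(w^2 * x_i^2) = 2 * w * sum(x_i^2) = 2 * 2 * 10.
grad_analytic = 2 * 2.0 * jnp.sum(x ** 2)
```

All three routes agree on value 40.0 and gradient 40.0; the fused call simply avoids the duplicated forward pass.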
Common pitfalls
- (value, grad) = jax.grad(f)(args) is wrong: jax.grad returns ONLY the gradient. Use value_and_grad to get both.
- Confusion about which is which: the result is (value, grad), so the function value comes first.
- has_aux=True: lets the inner function return (loss, aux), and value_and_grad then returns ((loss, aux), grad). The aux output is passed through but not differentiated, which makes it useful for logging metrics during training without affecting the gradient.
- Differentiation argnums: same as jax.grad; the argnums=k parameter selects which argument to differentiate with respect to (default 0).
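The has_aux and argnums pitfalls are easiest to see in the shapes of the returned tuples. The sketch below uses a hypothetical loss_with_aux that returns a metrics dict as its aux output; the function name and the "mean_pred" metric are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss that also returns auxiliary metrics.
def loss_with_aux(w, b, x):
    pred = w * x + b
    loss = jnp.sum(pred ** 2)
    aux = {"mean_pred": jnp.mean(pred)}  # passed through, not differentiated
    return loss, aux

x = jnp.array([1.0, 2.0])

# has_aux=True: the result is ((loss, aux), grad); default argnums=0 -> d/dw.
(loss, aux), grad_w = jax.value_and_grad(loss_with_aux, has_aux=True)(0.0, 1.0, x)

# argnums=(0, 1): the gradient becomes a tuple (d/dw, d/db).
(loss2, aux2), (gw, gb) = jax.value_and_grad(
    loss_with_aux, argnums=(0, 1), has_aux=True
)(0.0, 1.0, x)
```

With w=0, b=1, x=[1, 2], every prediction is 1, so loss = 2, mean_pred = 1, d/dw = 2·1 + 2·2 = 6, and d/db = 2 + 2 = 4.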
Problem
Implement loss_and_grad(w, b, x) that computes both the loss
f(w, b) = sum_i (w·x_i + b)² and its gradient w.r.t. w only,
in a single pass using jax.value_and_grad.
Return a 1-D array of shape (2,): [loss, grad_w].
Two illustrative examples (not from the test set):
- w=0.0, b=1.0, x=[1.0, 2.0]: loss = sum((0·x_i + 1)²) = 1 + 1 = 2.0; grad = sum(2(0·x_i + 1)·x_i) = 2·1 + 2·2 = 6.0 → [2.0, 6.0]
- w=1.0, b=2.0, x=[3.0]: loss = (1·3 + 2)² = 25.0; grad = 2(1·3 + 2)·3 = 30.0 → [25.0, 30.0]
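One possible sketch of loss_and_grad follows; it is not the site's reference solution. It assumes w and b are floating-point scalars (jax.grad rejects integer inputs) and that x is a 1-D sequence of floats. It relies on value_and_grad's default argnums=0, so only the gradient with respect to w is computed.

```python
import jax
import jax.numpy as jnp

def loss_and_grad(w, b, x):
    x = jnp.asarray(x)

    def f(w, b, x):
        # f(w, b) = sum_i (w * x_i + b)^2
        return jnp.sum((w * x + b) ** 2)

    # argnums defaults to 0: differentiate w.r.t. w only, in one pass.
    loss, grad_w = jax.value_and_grad(f)(w, b, x)
    # Pack the two scalars into a 1-D array of shape (2,).
    return jnp.stack([loss, grad_w])
```

Called on the two illustrative inputs, it should reproduce [2.0, 6.0] and [25.0, 30.0].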