Gradient with jax.grad
Why this matters
jax.grad is the workhorse of autodiff in JAX. Given a scalar-valued
function, jax.grad(f) returns a new function that computes
∂f/∂(first argument). The argnums keyword (default 0) lets you choose
which argument to differentiate with respect to. One hard constraint: the
function must return a scalar. jax.grad raises a TypeError on vector-valued
functions (use jax.jacobian for those). This mirrors the mathematical
definition: a gradient is defined for scalar-valued multivariable functions.
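To make the scalar-output constraint concrete, here is a minimal sketch. The names scalar_fn and vector_fn are hypothetical helpers introduced only for this illustration; the sketch assumes a recent JAX release in which jax.grad rejects non-scalar outputs with a TypeError.

```python
import jax
import jax.numpy as jnp

def scalar_fn(x):
    # Returns a scalar, so jax.grad accepts it.
    return jnp.sum(x ** 2)

def vector_fn(x):
    # Returns an array shaped like x, so jax.grad rejects it.
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])

# Gradient of sum(x^2) with respect to x is 2x.
grad_scalar = jax.grad(scalar_fn)(x)

# For a vector-valued function, ask for the full Jacobian instead.
# Here it is a 3x3 matrix with 2x on the diagonal.
jac_vector = jax.jacobian(vector_fn)(x)

# jax.grad on the vector-valued function is expected to fail.
try:
    jax.grad(vector_fn)(x)
    grad_failed = False
except TypeError:
    grad_failed = True
```

If only a single number is needed from a vector-valued function, reducing it (for example with jnp.sum) before differentiating is usually simpler than computing the whole Jacobian.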
Worked mini-example
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum(w * x)
df_dw = jax.grad(f) # gradient wrt w (first arg)
df_dw(2.0, jnp.array([1.0, 3.0])) # → 4.0 (1 + 3)
# Gradient wrt a different arg:
df_dx = jax.grad(f, argnums=1)
df_dx(2.0, jnp.array([1.0, 3.0])) # → [2.0, 2.0]
# Gradient wrt multiple args (returns a tuple of gradients):
df_both = jax.grad(f, argnums=(0, 1))
df_both(2.0, jnp.array([1.0, 3.0])) # → (4.0, [2.0, 2.0])
Common pitfalls
- Non-scalar output: jax.grad raises TypeError if the function returns anything other than a scalar. Fix: wrap in jnp.sum or jnp.mean.
- Forgetting argnums: it defaults to 0. If you need the gradient w.r.t. a later argument, set argnums=k.
- Differentiating w.r.t. integers: jax.grad requires float-typed inputs for the differentiated argument. Cast int inputs to float first.
- Calling instead of wrapping: jax.grad(loss_fn(w, b, x)) is wrong; you pass the function, not its return value.
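The pitfalls above can be reproduced in a few lines. This is a sketch, not a reference: loss_fn is a hypothetical scalar loss made up for the illustration, and the exact error messages may differ between JAX versions, so the code only checks that a TypeError is raised.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, b, x):
    # Hypothetical scalar loss, used only for this illustration.
    return jnp.sum((w * x + b) ** 2)

x = jnp.array([1.0, 2.0])

# Pitfall: integer input for the differentiated argument (w = 1).
try:
    jax.grad(loss_fn)(1, 2.0, x)
    int_rejected = False
except TypeError:
    int_rejected = True

# Fix: cast to float before differentiating.
g_w = jax.grad(loss_fn)(float(1), 2.0, x)

# argnums selects a later argument; here the gradient is w.r.t. b.
g_b = jax.grad(loss_fn, argnums=1)(1.0, 2.0, x)

# Pitfall: passing the return value of loss_fn instead of loss_fn itself.
try:
    jax.grad(loss_fn(1.0, 2.0, x))(1.0, 2.0, x)
    call_rejected = False
except TypeError:
    call_rejected = True
```

With w=1.0, b=2.0 and x=[1.0, 2.0], the residuals w·x + b are [3.0, 4.0], so the gradient w.r.t. w is 2·3·1 + 2·4·2 = 22.0 and the gradient w.r.t. b is 2·3 + 2·4 = 14.0.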
Problem
Implement grad_of_quadratic(w, b, x) that computes the gradient of
f(w, b) = ∑ᵢ (w·xᵢ + b)² with respect to w only.
The closed-form answer is d/dw [∑ᵢ (w·xᵢ + b)²] = ∑ᵢ 2·(w·xᵢ + b)·xᵢ,
but you should use jax.grad to compute it automatically.
Two illustrative examples (not from the test set):
- w=0.0, b=1.0, x=[1.0, 2.0]: d/dw sum((0·x+1)²) = sum(2·(0·x+1)·x) = 2·1 + 2·2 = 6.0
- w=1.0, b=2.0, x=[3.0]: d/dw (1·3+2)² = 2·(3+2)·3 = 30.0
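One way to sanity-check the two examples is to compare jax.grad against the closed-form expression. The sketch below is only one possible approach; quadratic and closed_form are helper names introduced here, not part of the problem statement.

```python
import jax
import jax.numpy as jnp

def quadratic(w, b, x):
    # f(w, b) = sum_i (w * x_i + b)^2, a scalar.
    return jnp.sum((w * x + b) ** 2)

def grad_of_quadratic(w, b, x):
    # argnums defaults to 0, so this differentiates w.r.t. w only.
    return jax.grad(quadratic)(w, b, x)

def closed_form(w, b, x):
    # Hand-derived gradient: sum_i 2 * (w * x_i + b) * x_i.
    return jnp.sum(2.0 * (w * x + b) * x)

x1 = jnp.array([1.0, 2.0])
x2 = jnp.array([3.0])

g1 = grad_of_quadratic(0.0, 1.0, x1)   # first example above
g2 = grad_of_quadratic(1.0, 2.0, x2)   # second example above
```

Both values should agree with closed_form on the same inputs, which is a quick check that the automatic gradient matches the derivation.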