easy primitives

Gradient with jax.grad

Why this matters

jax.grad is the workhorse of autodiff in JAX. Given a scalar-valued function f, jax.grad(f) returns a new function that computes the gradient of f with respect to its first argument. The argnums keyword (default 0) selects which argument, or tuple of arguments, to differentiate with respect to. One hard constraint: f must return a scalar. jax.grad does not accept vector-valued functions (use jax.jacobian for those). This mirrors the mathematical definition: a gradient is defined for scalar-valued functions of several variables.
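To make the scalar-output constraint concrete, here is a minimal sketch contrasting jax.grad on a scalar-valued function with jax.jacobian on a vector-valued one. The names scalar_loss and vector_fn are illustrative only and do not come from the problem.

```python
import jax
import jax.numpy as jnp

def scalar_loss(w):
    # Scalar output, so jax.grad applies directly.
    return jnp.sum(w ** 2)

def vector_fn(w):
    # Vector output, so jax.jacobian is the appropriate transform.
    return w ** 2

w = jnp.array([1.0, 2.0, 3.0])

grad_w = jax.grad(scalar_loss)(w)    # same shape as w: (3,)
jac_w = jax.jacobian(vector_fn)(w)   # one row per output: (3, 3)
```

The gradient has the shape of the input, while the Jacobian adds one axis for the outputs.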

Worked mini-example

import jax, jax.numpy as jnp

def f(w, x):
    return jnp.sum(w * x)

df_dw = jax.grad(f)                         # gradient wrt w (first arg)
df_dw(2.0, jnp.array([1.0, 3.0]))           # → 4.0  (1 + 3)

# Gradient wrt a different arg:
df_dx = jax.grad(f, argnums=1)
df_dx(2.0, jnp.array([1.0, 3.0]))           # → [2.0, 2.0]

# Gradient wrt multiple args (returns a tuple of gradients):
df_both = jax.grad(f, argnums=(0, 1))
df_both(2.0, jnp.array([1.0, 3.0]))         # → (4.0, [2.0, 2.0])

Common pitfalls

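  • Non-scalar output: jax.grad raises TypeError if the function returns anything other than a scalar. Fix: reduce the output to a scalar with jnp.sum or jnp.mean if that is the quantity you want to differentiate, or use jax.jacobian for the full derivative.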
  • Forgetting argnums: defaults to 0. If you need the gradient w.r.t. a later argument, set argnums=k.
  • Differentiating w.r.t. integers: jax.grad requires float-typed inputs for the differentiated argument. Cast int inputs to float first.
  • Calling instead of wrapping: jax.grad(loss_fn(w, b, x)) is wrong — you pass the function, not its return value.
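The first and third pitfalls above can be reproduced in a few lines. This is a sketch, not part of the problem: g is a hypothetical vector-valued variant of f from the worked example, and the flags only record whether the expected TypeError occurred.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum(w * x)

def g(w, x):
    # Hypothetical variant of f that returns a vector instead of a scalar.
    return w * x

x = jnp.array([1.0, 3.0])

# Non-scalar output: the error surfaces when the gradient function is called.
try:
    jax.grad(g)(2.0, x)
    non_scalar_failed = False
except TypeError:
    non_scalar_failed = True

# Integer input for the differentiated argument (w=2 rather than 2.0).
try:
    jax.grad(f)(2, x)
    int_failed = False
except TypeError:
    int_failed = True

# Casting the integer to float resolves the second case.
grad_ok = jax.grad(f)(float(2), x)
```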

Problem

Implement grad_of_quadratic(w, b, x), which computes the gradient of f(w, b) = ∑ᵢ (w·xᵢ + b)² with respect to w only.

The closed-form answer is d/dw [∑ᵢ (w·xᵢ + b)²] = ∑ᵢ 2·(w·xᵢ + b)·xᵢ, but you should use jax.grad to compute it automatically.

Two illustrative examples (not from the test set):

  • w=0.0, b=1.0, x=[1.0, 2.0]: d/dw sum((0*x + 1)^2) = sum(2*(0*x + 1)*x) = 2*1 + 2*2 = 6.0
  • w=1.0, b=2.0, x=[3.0]: d/dw (1*3 + 2)^2 = 2*(3 + 2)*3 = 30.0
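As a sanity check on the arithmetic, the closed-form expression can be evaluated directly. This sketch deliberately does not use jax.grad, so it can serve as a reference for checking your own solution. The helper name closed_form_grad_w is an assumption made here, not part of the problem.

```python
import jax.numpy as jnp

def closed_form_grad_w(w, b, x):
    # Hypothetical helper: evaluates sum_i 2 * (w * x_i + b) * x_i directly.
    x = jnp.asarray(x, dtype=jnp.float32)
    return jnp.sum(2.0 * (w * x + b) * x)

example_1 = closed_form_grad_w(0.0, 1.0, [1.0, 2.0])  # 2*1*1 + 2*1*2
example_2 = closed_form_grad_w(1.0, 2.0, [3.0])       # 2*(3 + 2)*3
```

An implementation built on jax.grad should agree with these values up to floating-point tolerance.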
