
vjp Basics

Why this matters

jax.vjp is JAX's reverse-mode autodiff primitive. For a function f: R^n → R^m, jax.vjp(f, x) runs the forward pass and returns the output f(x) together with a function vjp_fn. Given a cotangent vector g with the same shape as the output, vjp_fn(g) computes the vector-Jacobian product gᵀ · J_f(x) in a single backward pass.
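To make gᵀ · J_f(x) concrete, here is a small sketch with a vector-valued function. The function h, the input and the cotangent are made up for illustration. It computes the VJP once through jax.vjp and once by materialising the Jacobian with jax.jacrev, then compares the two.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Hypothetical f: R^3 -> R^2, chosen only for illustration.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
g = jnp.array([1.0, -1.0])            # cotangent: same shape as h(x), i.e. (2,)

out, vjp_fn = jax.vjp(h, x)           # forward pass; vjp_fn closes over the residuals
(vjp_val,) = vjp_fn(g)                # g^T J, which has the same shape as x: (3,)

J = jax.jacrev(h)(x)                  # explicit Jacobian, shape (2, 3)
print(jnp.allclose(vjp_val, g @ J))   # True
```

The point of jax.vjp is that it never builds J: the backward pass produces gᵀ · J directly, which matters when J would be large.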

This is the engine behind jax.grad: for a scalar-output f, jax.grad(f)(x) is exactly jax.vjp(f, x)[1](1.0)[0]. Knowing the primitive lets you:

  • Apply multiple cotangents without re-running the forward pass.
  • Implement custom reverse-mode rules via jax.custom_vjp.
  • Build batched Jacobians row-by-row (next problem).
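For the jax.custom_vjp bullet, a minimal sketch: it re-implements this section's f(x) = sum(x²) with a hand-written backward rule. The names sum_sq, sum_sq_fwd and sum_sq_bwd are illustrative, not part of the problem. Real custom rules usually exist for numerical stability or efficiency; this one simply reproduces 2x so the result is easy to check against jax.grad.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def sum_sq(x):
    return jnp.sum(x ** 2)

def sum_sq_fwd(x):
    # Forward pass: return the output plus the residuals the backward pass needs.
    return jnp.sum(x ** 2), x

def sum_sq_bwd(x, g):
    # Backward pass: map the output cotangent g to a tuple with one
    # cotangent per primal input (here just x).
    return (g * 2.0 * x,)

sum_sq.defvjp(sum_sq_fwd, sum_sq_bwd)

x = jnp.array([1.0, 2.0, 3.0])
out, vjp_fn = jax.vjp(sum_sq, x)
print(vjp_fn(1.0)[0])         # computed by sum_sq_bwd: [2. 4. 6.]
print(jax.grad(sum_sq)(x))    # jax.grad routes through the same custom rule
```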

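Reverse mode is cheapest when the output dimension m is much smaller than the input dimension n (m ≪ n): a full Jacobian costs one backward pass per output, so a scalar loss (m = 1) gets its entire gradient from a single pass. That is the classic loss-function scenario in ML.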
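The cost argument can be seen by building a full Jacobian from VJPs: each row of J is one backward pass with a one-hot cotangent, so the work grows with m. The sketch below uses a hypothetical function h with m = 2 and n = 4, batches the m passes with jax.vmap, and compares the result with jax.jacrev, which is built on the same vmap-over-VJP idea.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Hypothetical f: R^4 -> R^2 (m = 2 outputs, n = 4 inputs).
    return jnp.stack([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0, 4.0])
out, vjp_fn = jax.vjp(h, x)    # a single forward pass, shared by all rows

m = out.shape[0]
basis = jnp.eye(m)             # one-hot cotangents e_0, ..., e_{m-1}

# Row i of J is e_i^T J: one reverse pass per output, batched by vmap.
J_rows = jax.vmap(lambda e: vjp_fn(e)[0])(basis)   # shape (m, n) = (2, 4)

print(J_rows.shape)                                # (2, 4)
print(jnp.allclose(J_rows, jax.jacrev(h)(x)))      # True
```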

Worked mini-example

import jax, jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
primal, vjp_fn = jax.vjp(f, x)
# primal  → 14.0       (f(x) = 1 + 4 + 9)

grad = vjp_fn(1.0)[0]
# grad    → [2., 4., 6.]   (2*x with cotangent=1.0)

scaled = vjp_fn(2.0)[0]
# scaled  → [4., 8., 12.]  (2*x scaled by cotangent=2.0)
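As a cross-check, the sketch below repeats the mini-example, compares the cotangent-1.0 result with jax.grad, and confirms that vjp_fn is linear in its cotangent. Linearity is why scaled came out as exactly twice grad. Only the example's f and x are reused; the loop values are arbitrary.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
primal, vjp_fn = jax.vjp(f, x)

# For scalar-output f, jax.grad(f)(x) matches the VJP with cotangent 1.0.
g_vjp = vjp_fn(1.0)[0]
g_grad = jax.grad(f)(x)
print(jnp.allclose(g_vjp, g_grad))                    # True

# vjp_fn is linear in the cotangent: vjp_fn(c) == c * vjp_fn(1.0).
for c in [0.5, 2.0, -3.0]:
    print(c, jnp.allclose(vjp_fn(c)[0], c * g_vjp))   # True for each c
```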

Common pitfalls

  • vjp_fn returns a TUPLE — index [0] to get the gradient for the first (and only) input argument.
  • Cotangent shape must match OUTPUT shape — for scalar f, pass a Python float; for array f, pass an array of the same shape.
  • jax.vjp(f(x), cotangent) is wrong: it hands jax.vjp a value instead of a function. Pass the FUNCTION and its inputs, jax.vjp(f, x), and give the cotangent to the returned vjp_fn.
  • Reusing vjp_fn is allowed: you can call vjp_fn(g) many times with different cotangents, and the forward pass is not re-run.
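The tuple and shape pitfalls show up most clearly with a multi-argument, array-valued function. In the sketch below, f2 is a hypothetical example. vjp_fn returns one cotangent per input, the cotangent has to match the output's shape, and a scalar cotangent for an array output is rejected. The exact exception type can differ between JAX versions, so the sketch catches both TypeError and ValueError.

```python
import jax
import jax.numpy as jnp

def f2(x, y):
    # Hypothetical two-argument function with an array output of shape (3,).
    return x * y

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
out, vjp_fn = jax.vjp(f2, x, y)   # pass the function and its inputs, not f2(x, y)

ct = jnp.ones_like(out)           # cotangent must have the output's shape, (3,)
dx, dy = vjp_fn(ct)               # a tuple: one entry per input argument
print(dx)                         # equals y: [4. 5. 6.]
print(dy)                         # equals x: [1. 2. 3.]

# A scalar cotangent does not match the (3,)-shaped output, so JAX rejects it.
rejected = False
try:
    vjp_fn(1.0)
except (TypeError, ValueError) as err:
    rejected = True
    print("rejected:", type(err).__name__)
```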

Problem

Implement vjp_quadratic(x, cotangent), which applies jax.vjp to f(x) = sum(x²) at x and returns the vector-Jacobian product for the given cotangent.

  • x: 1-D jax array.
  • cotangent: scalar (Python float).

Returns: a 1-D array with the same shape as x, namely the vector-Jacobian product cotangent · 2x.

For cotangent=1.0 this equals jax.grad(f)(x) = 2*x.

Use jax.vjp — do not compute the gradient by hand.

Hints

jax vjp reverse-mode
