
jax.linearize Primitive

Why this matters

jax.linearize(f, *primals) returns (primal_out, lin_fn), where primal_out = f(*primals) and lin_fn is a linear function of the tangent. Calling lin_fn(v) computes the Jacobian-vector product Jf(x) · v, i.e. the directional derivative of f along v at the point x where f was linearized.
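To see that lin_fn really computes the Jacobian-vector product, a small sketch can compare it against an explicitly materialized Jacobian from jax.jacfwd. The function g and the values of x and v are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical vector-valued function R^3 -> R^2, used only for illustration.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.5])
v = jnp.array([0.3, -1.0, 2.0])

y, lin_g = jax.linearize(g, x)

jvp_via_lin = lin_g(v)          # Jg(x) @ v, without building Jg(x)
J = jax.jacfwd(g)(x)            # explicit 2x3 Jacobian, for comparison only
jvp_via_matrix = J @ v

print(jnp.allclose(jvp_via_lin, jvp_via_matrix))  # True
```

Materializing J is only done here as a cross-check; the point of lin_g is that it never builds the matrix.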

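This differs from jax.jvp, which computes the primal output and a single tangent output together in one call. linearize is useful when you need to apply the same Jacobian to many different tangents: f is evaluated and linearized once, lin_fn keeps the linear part of the computation together with the intermediate values it needs, and each lin_fn(v_i) call evaluates only that linear part. The trade-off is that those intermediate values stay in memory for as long as lin_fn is alive.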
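A minimal sketch of the reuse pattern: linearize once, push a whole batch of tangents through lin_fn with jax.vmap, and check the result against calling jax.jvp separately for each tangent. The function f, the point x, and the random tangents are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative elementwise nonlinear function R^3 -> R^3.
    return jnp.tanh(x) * jnp.exp(x)

x = jnp.array([0.1, -0.5, 1.2])
# 8 tangent directions, one per row.
tangents = jax.random.normal(jax.random.PRNGKey(0), (8, 3))

# Linearize once: f(x) is evaluated here; lin_fn keeps only the linear part.
y, lin_fn = jax.linearize(f, x)

# Reuse lin_fn for every tangent; vmap maps over the leading (batch) axis.
batched = jax.vmap(lin_fn)(tangents)

# Reference: one jax.jvp call per tangent, each re-evaluating f(x).
reference = jnp.stack([jax.jvp(f, (x,), (t,))[1] for t in tangents])

print(batched.shape)                                 # (8, 3)
print(jnp.allclose(batched, reference, atol=1e-6))   # True
```

Because lin_fn is an ordinary JAX-traceable function, it composes with vmap and jit like any other.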

Real-world uses:

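  • Sensitivity analysis: evaluate lin_fn along many perturbation directions around one fixed point.
  • Local approximation: the first-order Taylor expansion f(x + v) ≈ f(x) + lin_fn(v) near the linearization point. Higher-order terms need nested differentiation or Taylor-mode AD; lin_fn alone is first order.
  • Solvers and physics simulations: in Newton-type iterations, lin_fn acts as a matrix-free Jacobian operator.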
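Around the traced point, lin_fn gives the first-order approximation f(x + εv) ≈ f(x) + ε·lin_fn(v). The sketch below checks this numerically for the quadratic f used in the worked example on this page; the direction v and step size ε are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])   # perturbation direction (illustrative)
eps = 1e-3                  # perturbation size (illustrative)

f_x, lin = jax.linearize(f, x)

exact_change = f(x + eps * v) - f_x   # true change in f
linear_change = eps * lin(v)          # first-order prediction: eps * Jf(x) v

# For this quadratic f the gap is exactly eps**2 * |v|^2 = 2e-6,
# up to float32 rounding in exact_change.
print(exact_change, linear_change)
```

For a quadratic the error term is easy to write down; for general f it shrinks like ε², which is what makes lin_fn useful as a local model.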

Worked mini-example

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])
v1 = jnp.array([1.0, 0.0])
v2 = jnp.array([0.0, 1.0])

primal, lin = jax.linearize(f, x)
print(primal)    # 5.0  (1² + 2²)
print(lin(v1))   # 2.0  (∂f/∂x₀ = 2·1 = 2)
print(lin(v2))   # 4.0  (∂f/∂x₁ = 2·2 = 4)
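Because lin is linear, applying it to each standard basis vector recovers one entry of the gradient (in general, one column of the Jacobian). A short, self-contained sketch using the same f and x, with jax.vmap evaluating all basis directions at once:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])
primal, lin = jax.linearize(f, x)

# Each row of the identity is a basis direction e_i; lin(e_i) = ∂f/∂x_i.
basis = jnp.eye(x.shape[0])
grad_from_lin = jax.vmap(lin)(basis)

print(grad_from_lin)    # [2. 4.]
print(jax.grad(f)(x))   # [2. 4.]
```

For a scalar-valued f with many inputs this costs one linear evaluation per input, so reverse mode (jax.grad or jax.vjp) is usually the better tool there; the basis trick pays off when f has few inputs and many outputs.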

Common pitfalls

  • lin_fn is a function: call it as lin_fn(v) to get the tangent output; lin_fn on its own is not a value.
  • jax.jvp vs jax.linearize: jvp returns (primal_out, tangent_out) directly, while linearize returns (primal_out, lin_fn). For the same tangent they produce the same value, but linearize lets you apply the linearization to further tangents without re-evaluating f.
  • Shapes must match: the tangent passed to lin_fn must have the same shape and dtype as x; a tangent of a different shape raises an error.
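The three pitfalls can be exercised in one short sketch. It uses the same f as the worked example; the tangent values are arbitrary, and because the exact exception type raised for a mis-shaped tangent is an implementation detail, the sketch catches both ValueError and TypeError.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])

primal, lin = jax.linearize(f, x)

# Pitfall 1: lin is a function; the tangent value comes from calling it.
assert callable(lin)
tangent_lin = lin(v)

# Pitfall 2: jvp and linearize agree on both the primal and the tangent.
primal_jvp, tangent_jvp = jax.jvp(f, (x,), (v,))
assert jnp.allclose(primal, primal_jvp)
assert jnp.allclose(tangent_lin, tangent_jvp)

# Pitfall 3: a tangent whose shape differs from x is rejected.
try:
    lin(jnp.array([1.0, 0.0, 0.0]))
    shape_error_raised = False
except (ValueError, TypeError):  # exact exception type may vary by JAX version
    shape_error_raised = True

print(shape_error_raised)  # True
```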

Problem

Implement linearize_at_point(x, v) that:

  1. Defines f(x) = sum(x²).
  2. Uses jax.linearize(f, x) to get primal_out and lin_fn.
  3. Applies lin_fn(v) to get the directional derivative.
  4. Returns a 1-D array [primal_out, tangent_out].

Arguments:

  • x: 1-D JAX array.
  • v: 1-D JAX array with the same shape as x.

Returns: 1-D array of shape (2,): [sum(x²), 2·(x·v)], where x·v is the dot product of x and v.
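One possible sketch of linearize_at_point, following the four steps above. Packing the two scalars with jnp.stack is our choice for producing a shape-(2,) array, not something the problem prescribes; the test values are illustrative.

```python
import jax
import jax.numpy as jnp

def linearize_at_point(x, v):
    """Return [sum(x**2), 2 * dot(x, v)] computed via jax.linearize."""
    def f(z):
        return jnp.sum(z ** 2)

    primal_out, lin_fn = jax.linearize(f, x)      # evaluate f(x) and keep its linearization
    tangent_out = lin_fn(v)                       # directional derivative Jf(x) · v
    return jnp.stack([primal_out, tangent_out])   # shape (2,)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])
print(linearize_at_point(x, v))  # [14. -4.]
```

Here sum(x²) = 1 + 4 + 9 = 14 and 2·(x·v) = 2·(1 + 0 − 3) = −4.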

Hints

jax linearize jvp
