jax.linearize Primitive
Why this matters
jax.linearize(f, *primals) returns (primal_out, lin_fn), where lin_fn is a linear function of the tangent: lin_fn(v) computes the directional derivative Jf(x) · v at the point x where f was traced.
This differs from jax.jvp, which computes the primal output and a single tangent together in one call. linearize is useful when you need to apply the same Jacobian to many different tangents: f is evaluated at x only once, its linearized computation is stored, and each later lin_fn(v_i) call only evaluates that linear map, which is cheap compared with re-running f. The trade-off is memory: the stored linearization keeps intermediate values from the primal computation alive, much as reverse mode does.
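A minimal sketch of the reuse pattern described above. The function f, the point x, and the tangent directions are arbitrary choices made only for illustration; the loop checks each lin_fn(v) result against a fresh jax.jvp call, which recomputes the primal every time.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Elementwise nonlinear map R^3 -> R^3, chosen only for illustration.
    return jnp.sin(x) * x


x = jnp.array([0.5, 1.0, 2.0])

# Trace and evaluate f at x once; lin_fn is the linear map v -> Jf(x) @ v.
y, lin_fn = jax.linearize(f, x)

# Several tangent directions (one per row) pushed through the same Jacobian.
tangents = jnp.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 1.0],
])

# Each call reuses the stored linearization instead of re-running f.
outs = [lin_fn(v) for v in tangents]

# jax.jvp computes primal and tangent together on every call; the tangents agree.
for v, t in zip(tangents, outs):
    _, t_jvp = jax.jvp(f, (x,), (v,))
    print(bool(jnp.allclose(t, t_jvp)))  # True
```

Because this f acts elementwise, its Jacobian is diagonal, so the all-ones tangent returns the diagonal sin(x) + x·cos(x).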
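Real-world uses:
- Sensitivity analysis: query lin_fn with many perturbation directions.
- Taylor approximations: build approximations around a fixed point x, starting from the first-order model f(x + v) ≈ f(x) + lin_fn(v).
- Solvers and physics simulations that repeatedly apply one linearization, e.g. Newton-Krylov iterations that need many Jacobian-vector products at the current point.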
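As a sketch of the sensitivity-analysis and approximation uses listed above, the snippet compares the first-order prediction f(x) + eps · lin_fn(v) with the actual perturbed value f(x + eps · v) for each unit direction. The function f, the point x, and the step eps are hypothetical values chosen only for illustration; the two numbers should agree up to an O(eps²) error.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar "model output", used only for illustration.
    return jnp.sum(jnp.exp(-x) * x ** 2)


x = jnp.array([0.3, 1.2, 2.5])
fx, lin_fn = jax.linearize(f, x)

eps = 1e-3                        # perturbation size (assumed)
directions = jnp.eye(x.shape[0])  # perturb one input at a time

for i, v in enumerate(directions):
    predicted = fx + eps * lin_fn(v)  # first-order Taylor model around x
    actual = f(x + eps * v)           # exact value at the perturbed point
    print(f"input {i}: sensitivity={float(lin_fn(v)):.4f}, "
          f"predicted={float(predicted):.6f}, actual={float(actual):.6f}")
```

Every direction reuses the same lin_fn, so probing more directions adds only the cost of the linear map, not another evaluation of f at x.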
Worked mini-example
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])
v1 = jnp.array([1.0, 0.0])
v2 = jnp.array([0.0, 1.0])

primal, lin = jax.linearize(f, x)
print(primal)   # → 5.0  (1² + 2²)
print(lin(v1))  # → 2.0  (∂f/∂x₀ = 2·1 = 2)
print(lin(v2))  # → 4.0  (∂f/∂x₁ = 2·2 = 4)
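Extending the worked example: because lin is linear, applying it to each standard basis vector yields one partial derivative per call, so stacking the results recovers the full gradient. This sketch compares that against jax.grad.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0])
primal, lin = jax.linearize(f, x)

# lin(e_i) = ∂f/∂x_i, so stacking over the basis vectors gives the gradient 2·x.
basis = jnp.eye(x.shape[0])
grad_from_lin = jnp.stack([lin(e) for e in basis])

print(grad_from_lin)   # [2. 4.]
print(jax.grad(f)(x))  # [2. 4.]
```

For a scalar-valued f with n inputs this costs n calls to lin, whereas reverse mode (jax.grad) gets the whole gradient in one pass. linearize pays off when you need J·v for specific directions, or when f has many outputs relative to its inputs.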
Common pitfalls
- lin_fn is a function: you must call it, as in lin_fn(v); lin_fn on its own is not the tangent.
- jax.jvp vs jax.linearize: jvp returns (primal, tangent) directly, while linearize returns (primal, lin_fn). Both give the same tangent value, but linearize stores the linearized computation so it can be reused for other tangents.
- Shapes must match: calling lin_fn with a tangent whose shape differs from x raises an error.
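The three pitfalls above can be reproduced in a few lines. This is a sketch: the exact exception type and message for the shape mismatch depend on the JAX version, so the code catches Exception broadly rather than asserting a specific class.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])

# Pitfall 1: linearize returns a function; the tangent exists only after calling it.
primal, lin_fn = jax.linearize(f, x)
tangent = lin_fn(v)  # 2 * dot(x, v) = 2 * (1 + 2) = 6.0

# Pitfall 2: jax.jvp takes tuples of primals and tangents and returns both outputs.
primal_jvp, tangent_jvp = jax.jvp(f, (x,), (v,))
assert jnp.allclose(tangent, tangent_jvp)

# Pitfall 3: a tangent whose shape differs from x is rejected.
shape_error_raised = False
try:
    lin_fn(jnp.ones(3))
except Exception as err:  # exact type/message depends on the JAX version
    shape_error_raised = True
    print("shape mismatch:", err)
```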
Problem
Implement linearize_at_point(x, v) that:
- Defines f(x) = sum(x²).
- Uses jax.linearize(f, x) to get primal_out and lin_fn.
- Applies lin_fn(v) to get the directional derivative tangent_out.
- Returns a 1-D array [primal_out, tangent_out].
Inputs:
- x: 1-D jax array.
- v: 1-D jax array with the same shape as x.
Returns: 1-D array of shape (2,), equal to [sum(x²), 2·dot(x, v)].
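One possible sketch of linearize_at_point, assuming x and v are floating-point 1-D arrays of equal shape. It follows the four steps literally; jnp.stack is just one way to pack the two scalars into the required shape-(2,) array.

```python
import jax
import jax.numpy as jnp


def linearize_at_point(x, v):
    """Sketch: return [sum(x**2), directional derivative of sum(x**2) along v]."""
    def f(z):
        return jnp.sum(z ** 2)

    primal_out, lin_fn = jax.linearize(f, x)
    tangent_out = lin_fn(v)  # equals 2 * dot(x, v) for this f
    return jnp.stack([primal_out, tangent_out])


x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])
print(linearize_at_point(x, v))  # values 14.0 and 9.0: sum(x²) = 14, 2·dot(x, v) = 9
```

For this f the tangent equals 2·dot(x, v), which gives a quick independent check of the result.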