vmap over lax.cond
Why this matters
lax.cond takes a scalar predicate: one condition, one branch selected. But ML pipelines are batched, and you often need a different branch
per example. The idiomatic JAX solution is to write the per-element cond
logic in a function and vmap it over the batch dimension.
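This composition is easy to mistake for a compute-saving branch, but under vmap it behaves much like jnp.where. Because each element's predicate becomes a batched value, JAX lowers the batched cond to a select: both branches run for every element, and the result is picked per element. If the predicate is not batched (for example, it depends only on unbatched arguments), the cond stays a real conditional. The benefit of vmap + lax.cond is therefore structural. You write the per-example logic once, in scalar form, and vmap handles the batching. It does not skip an expensive branch. Because JAX transformations assume pure functions, neither branch should rely on side effects.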
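The minimal sketch below compares the two forms on a hypothetical per-example rule (`per_example`, which squares positive values and negates the rest). Both forms give the same values. Printing the jaxpr of the vmapped function shows how JAX batched the cond: with a batched predicate, the program ends in a select over both branch results rather than a conditional. The exact primitive name (`select_n` in recent releases) may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

def per_example(x):
    # Hypothetical rule: square positive values, negate the rest.
    return lax.cond(x > 0, lambda v: v * v, lambda v: -v, x)

xs = jnp.array([-2.0, 0.5, 3.0])

via_cond = jax.vmap(per_example)(xs)
via_where = jnp.where(xs > 0, xs * xs, -xs)  # evaluates both expressions, then selects
# Both hold [2.0, 0.25, 9.0].

# Inspect the batched program: both branch bodies appear, followed by a select.
jaxpr = jax.make_jaxpr(jax.vmap(per_example))(xs)
print(jaxpr)
```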
Worked mini-example
```python
import jax
import jax.numpy as jnp
from jax import lax

def threshold_each(xs, ts):
    """Return 1.0 where xs[i] > ts[i], else 0.0."""
    def single(x, t):
        return lax.cond(x > t, lambda: 1.0, lambda: 0.0)
    return jax.vmap(single, in_axes=(0, 0))(xs, ts)

threshold_each(jnp.array([3.0, 1.0, 2.0]),
               jnp.array([2.0, 2.0, 2.0]))  # → [1., 0., 0.]
```
in_axes=(0, 0) tells vmap to map over the leading axis of both
arguments simultaneously. This is also vmap's default (in_axes=0), so here it only makes the intent explicit.
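To see what in_axes controls, this sketch shows a variant in which only xs is batched. Every element is compared against a single shared threshold, which is passed with in_axes=None so each call receives it unchanged. The function name `threshold_shared` is hypothetical and does not appear in the original example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def threshold_shared(xs, t):
    """Hypothetical variant: 1.0 where xs[i] > t, else 0.0, for one shared t."""
    def single(x, t):
        return lax.cond(x > t, lambda: 1.0, lambda: 0.0)
    # Map over axis 0 of xs; pass t to every call as-is (not mapped).
    return jax.vmap(single, in_axes=(0, None))(xs, t)

shared = threshold_shared(jnp.array([3.0, 1.0, 2.0]), jnp.array(2.0))
# shared holds [1.0, 0.0, 0.0]
```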
Common pitfalls
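- Trying to vmap directly over lax.cond. vmap maps over array arguments, but two of lax.cond's arguments are Python callables (the branches). Rather than writing jax.vmap(lax.cond) and configuring in_axes to leave the branches alone, wrap the cond in a helper function whose arguments are all arrays, and vmap the helper.
- Closures in the batch loop. Inside the vmapped function, x and t are per-row tracers, not full arrays, so jnp.sum(x) reduces one row, not the entire batch. Arrays captured by closure are the opposite: vmap does not map them, so every call sees the whole array.
- Wrong in_axes. vmap's default in_axes=0 already maps the leading axis of every argument, so (0, 0) is explicit rather than required. The real mistake is marking a batched argument as None. Each call then receives the whole array instead of one element. Here, lax.cond rejects the resulting non-scalar predicate with a TypeError; in a function that reduces (for example with jnp.sum), the same mistake can silently compute over the whole batch instead. Batched arguments whose leading axes differ in size also raise an error.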
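The sketch below reproduces the last two pitfalls. First, an array captured by closure is seen whole inside every vmapped call. Second, marking a batched argument as None hands lax.cond a non-scalar predicate. The names `row_fraction` and `single` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])

def row_fraction(row):
    # `row` is one row of the batch (shape (2,)). `batch` is captured by
    # closure and is NOT mapped, so every call sees the full (2, 2) array.
    return jnp.sum(row) / jnp.sum(batch)

fractions = jax.vmap(row_fraction)(batch)  # row sums 3 and 7 over a total of 10

def single(x, t):
    return lax.cond(x > t, lambda: 1.0, lambda: 0.0)

xs = jnp.array([3.0, 1.0, 2.0])
ts = jnp.array([2.0, 2.0, 2.0])
try:
    # Bug: ts is marked unbatched, so each call receives all of ts and
    # `x > t` has shape (3,) instead of being a scalar predicate.
    jax.vmap(single, in_axes=(0, None))(xs, ts)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)
```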
Problem
Implement vmap_classify_batch(x_batch, thresholds) that classifies each
row of x_batch independently: output i is 1.0 if
sum(x_batch[i]) > thresholds[i], else -1.0.
- x_batch: 2-D JAX array of shape (N, d).
- thresholds: 1-D JAX array of shape (N,).
- Returns: 1-D array of shape (N,).
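One possible solution sketch follows. It follows the worked example: a per-row helper reduces the row and compares it to that row's threshold inside lax.cond, and vmap maps both arguments over their leading axis. The sample data is made up for illustration, and the jnp.where line is only a cross-check, not part of the required interface.

```python
import jax
import jax.numpy as jnp
from jax import lax

def vmap_classify_batch(x_batch, thresholds):
    """Sketch: 1.0 if sum(x_batch[i]) > thresholds[i], else -1.0."""
    def classify_row(row, t):
        # Inside vmap, `row` has shape (d,) and `t` is a scalar.
        return lax.cond(jnp.sum(row) > t, lambda: 1.0, lambda: -1.0)
    return jax.vmap(classify_row, in_axes=(0, 0))(x_batch, thresholds)

x_batch = jnp.array([[1.0, 2.0], [0.0, 0.5], [3.0, -1.0]])
thresholds = jnp.array([2.5, 1.0, 2.0])
out = vmap_classify_batch(x_batch, thresholds)
# Row sums are 3.0, 0.5, 2.0, so out holds [1.0, -1.0, -1.0].

# Cross-check against a fully vectorized version.
ref = jnp.where(x_batch.sum(axis=1) > thresholds, 1.0, -1.0)
```

Note that the strict comparison sends the third row (sum exactly equal to its threshold) to -1.0, as the problem statement requires.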