vmap-cond vs where: cost tradeoff
Why this matters
When you need per-element branching over a batch, JAX offers two idioms with identical numerical results but different cost profiles:
- vmap(lax.cond): expresses the branch per element, and both branches are traced once. The intuition is that each element then skips its unused path, which would pay off at low batch sizes with expensive branches. In practice, when the predicate depends on the mapped input (it is batched), JAX's batching rule rewrites the cond as a select, so both branches are computed for every element, just as with jnp.where.
- jnp.where: always evaluates both branches for every element, then selects. It vectorizes cleanly, with no per-element control-flow overhead.
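One way to check what vmap(lax.cond) actually becomes is to print its jaxpr. The sketch below reuses the x * 10 / x + 1 branches from the worked example (the names per_elem and batched are illustrative) and batches over the input only. Because the predicate then depends on the mapped value, the traced program should contain a select_n (plain select in older JAX releases) and no cond.

```python
import jax
import jax.numpy as jnp
from jax import lax

def per_elem(x, t):
    # Per-element branch; the predicate depends on x, which vmap will batch.
    return lax.cond(x > t, lambda v: v * 10.0, lambda v: v + 1.0, x)

xs = jnp.array([0.5, 2.0, 3.0])

# Map over xs only; t is passed through as an unbatched scalar.
batched = jax.vmap(per_elem, in_axes=(0, None))

# Inspect the traced program: with a batched predicate no cond remains,
# only both branch computations followed by a select.
jaxpr_text = str(jax.make_jaxpr(batched)(xs, 1.0))
print(jaxpr_text)
```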
Understanding this tradeoff lets you choose the right primitive for the
problem: lax.cond when you want explicit branch structure and short-circuit
semantics (which you only get when the predicate is unbatched, e.g. a single
scalar decision for the whole batch); jnp.where for simple element-wise
selection where both branches are cheap.
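Short-circuiting requires a predicate that is not batched. The following is a minimal sketch that assumes a single decision for the whole batch. The function name whole_batch_branch and the jnp.sin stand-in for an expensive branch are hypothetical. Here the jaxpr keeps a real cond, and at runtime XLA's conditional generally executes only the selected branch.

```python
import jax
import jax.numpy as jnp
from jax import lax

def whole_batch_branch(xs, t):
    # One scalar predicate for the whole batch: nothing is vmapped, so the
    # predicate is unbatched and lax.cond remains a genuine conditional.
    return lax.cond(
        jnp.sum(xs) > t,
        lambda v: jnp.sin(v) * 10.0,  # hypothetical stand-in for an expensive branch
        lambda v: v + 1.0,            # cheap branch
        xs,
    )

xs = jnp.array([0.5, 2.0, 3.0])
jaxpr_text = str(jax.make_jaxpr(whole_batch_branch)(xs, 1.0))
print(jaxpr_text)  # shows a cond primitive holding both branch sub-jaxprs

print(whole_batch_branch(xs, 100.0))  # sum is 5.5, not > 100, so the cheap branch runs
```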
Worked mini-example
```python
import jax
import jax.numpy as jnp
from jax import lax

def two_ways(xs, t):
    """Where xs > t return xs * 10, else xs + 1. Two implementations."""
    # Way (a): vmap over a per-element lax.cond
    def per_elem(x):
        return lax.cond(x > t, lambda v: v * 10.0, lambda v: v + 1.0, x)
    result_a = jax.vmap(per_elem)(xs)
    # Way (b): element-wise jnp.where (both branches computed, then selected)
    result_b = jnp.where(xs > t, xs * 10.0, xs + 1.0)
    return result_a, result_b  # numerically identical

two_ways(jnp.array([0.5, 2.0, 3.0]), 1.0)
# → ([1.5, 20., 30.], [1.5, 20., 30.])
```
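Because the two ways agree numerically, measuring is the only way to settle the cost question for a given workload. The following is a rough timing sketch that assumes a single device and the branch functions from the example above. The helper time_fn, the batch size, and the repeat count are illustrative choices. Numbers will vary with backend, batch size, and branch cost, so treat them as a starting point rather than a verdict.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

def via_cond(xs, t):
    # vmap over a per-element lax.cond (the predicate is batched).
    def per_elem(x):
        return lax.cond(x > t, lambda v: v * 10.0, lambda v: v + 1.0, x)
    return jax.vmap(per_elem)(xs)

def via_where(xs, t):
    # Compute both branches element-wise, then select.
    return jnp.where(xs > t, xs * 10.0, xs + 1.0)

def time_fn(fn, *args, repeats=50):
    """Average wall-clock seconds per call of jit(fn), excluding compilation."""
    compiled = jax.jit(fn)
    compiled(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so the actual work is timed.
        compiled(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats

xs = jnp.linspace(-1.0, 1.0, 10_000)
timings = {
    "vmap(cond)": time_fn(via_cond, xs, 0.0),
    "where": time_fn(via_where, xs, 0.0),
}
print(timings)
```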
Common pitfalls
- Forgetting to broadcast the condition in jnp.where. If x_batch has shape (N, d) and your per-row condition has shape (N,), you must expand it to cond[:, None]. Without this, jnp.where raises a broadcasting error or, when N == d, silently broadcasts the condition along the wrong axis.
- Expecting different results. Both methods produce the same numbers; the difference is purely about compiled performance, not correctness.
- vmap of cond on 2-D input. per_row_cond receives a 1-D row; the lambda: row * 2.0 must close over that per-row row, not the full batch.
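The broadcasting pitfall is easiest to see on a square batch, where N == d and the unexpanded condition broadcasts without complaint. This sketch uses made-up numbers. The names x_batch, threshold, row_cond, wrong, right, and x_rect are illustrative.

```python
import jax.numpy as jnp

x_batch = jnp.array([[1.0, 2.0],
                     [3.0, 4.0]])  # N = d = 2
threshold = 5.0
row_cond = jnp.sum(x_batch, axis=1) > threshold  # shape (N,): [False, True]

# Wrong: a (N,) condition aligns with the LAST axis, so it selects per column.
wrong = jnp.where(row_cond, x_batch * 2.0, x_batch + 1.0)
print(wrong)  # values [[2., 4.], [4., 8.]]

# Right: a (N, 1) condition broadcasts across each row.
right = jnp.where(row_cond[:, None], x_batch * 2.0, x_batch + 1.0)
print(right)  # values [[2., 3.], [6., 8.]]

# With N != d the unexpanded condition cannot broadcast at all.
x_rect = jnp.ones((3, 2))
broadcast_failed = False
try:
    jnp.where(jnp.sum(x_rect, axis=1) > 1.0, x_rect * 2.0, x_rect + 1.0)
except (ValueError, TypeError):
    broadcast_failed = True
print("broadcast failed:", broadcast_failed)
```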
Problem
Implement vmap_branch_two_ways(x_batch, threshold) that applies, to each row,
the branch logic row * 2 if sum(row) > threshold else row + 1, once via
vmap(lax.cond) and once via jnp.where.
- x_batch: 2-D JAX array of shape (N, d).
- threshold: scalar.
- Returns: 3-D array of shape (2, N, d), stacked as [result_via_cond, result_via_where]. Both slices are numerically identical.
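Here is a minimal sketch of one possible solution. It is written from the specification above and is not the site's reference solution. The inner helper name per_row_cond echoes the pitfall note, but its body is an assumption. The sketch evaluates the row-sum predicate once per row inside vmap(lax.cond) and once as an (N, 1) mask for jnp.where, then stacks the two results.

```python
import jax
import jax.numpy as jnp
from jax import lax

def vmap_branch_two_ways(x_batch, threshold):
    # Way (a): per-row lax.cond, mapped over the leading axis.
    def per_row_cond(row):
        # Branches take the 1-D row as an operand instead of closing over
        # the full batch.
        return lax.cond(jnp.sum(row) > threshold,
                        lambda r: r * 2.0,
                        lambda r: r + 1.0,
                        row)
    via_cond = jax.vmap(per_row_cond)(x_batch)

    # Way (b): jnp.where with the per-row condition expanded to (N, 1).
    mask = (jnp.sum(x_batch, axis=1) > threshold)[:, None]
    via_where = jnp.where(mask, x_batch * 2.0, x_batch + 1.0)

    # Shape (2, N, d): [result_via_cond, result_via_where].
    return jnp.stack([via_cond, via_where])

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
out = vmap_branch_two_ways(x, 5.0)
print(out)
```

Expanding the mask to (N, 1) is what keeps the jnp.where path correct when N happens to equal d.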