hard primitives

vmap over lax.cond

Why this matters

lax.cond handles a scalar predicate — one condition, one branch selected. But ML pipelines are batched: you often need a different branch per example. The idiomatic JAX solution is to wrap the per-element cond logic in a function and vmap over the batch dimension.
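A minimal sketch of the scalar case, using an illustrative helper relu_scalar that is not part of the lesson. lax.cond picks one branch for a scalar predicate. Handed an array-valued predicate outside vmap, it raises a TypeError, and that limitation is what vmap works around.

```python
import jax.numpy as jnp
from jax import lax

def relu_scalar(x):
    # Illustrative helper: one scalar predicate selects one branch.
    # Both branches receive the operand x and must return matching shape/dtype.
    return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

neg = relu_scalar(jnp.float32(-3.0))  # 0.0
pos = relu_scalar(jnp.float32(2.0))   # 2.0

# Outside vmap, an array-valued predicate is rejected.
try:
    relu_scalar(jnp.array([-3.0, 2.0]))
    batched_pred_rejected = False
except TypeError:
    batched_pred_rejected = True
```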

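Composing vmap with lax.cond is easy to mistake for a way to skip work per element, but it is not. When the predicate is batched, as it is here, vmap converts the cond into a select: both branches are computed for every element and the results are merged, which is the same work jnp.where does. The benefit is structural rather than computational. You write and test the per-example function with a scalar predicate, and vmap lifts it to the batch unchanged. lax.cond only avoids running the untaken branch when its predicate is not batched, that is, when one predicate value applies to the whole batch.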
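One way to check the select behaviour, sketched with an illustrative per_example function: the vmapped cond agrees element for element with a jnp.where that computes both branches up front. Printing the jaxpr lets you inspect the batched program yourself; the exact primitive names (for example select_n) may vary across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

def per_example(x):
    # Illustrative per-example rule: square positives, negate everything else.
    return lax.cond(x > 0, lambda v: v * v, lambda v: -v, x)

xs = jnp.array([-2.0, 0.5, 3.0])

via_cond = jax.vmap(per_example)(xs)
# jnp.where evaluates both branch expressions for all of xs, then selects.
via_where = jnp.where(xs > 0, xs * xs, -xs)

# With a batched predicate, expect a select over both branch results
# rather than a cond primitive in the printed program.
print(jax.make_jaxpr(jax.vmap(per_example))(xs))
```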

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

def threshold_each(xs, ts):
    """Return 1.0 where xs[i] > ts[i], else 0.0."""
    def single(x, t):
        return lax.cond(x > t, lambda: 1.0, lambda: 0.0)
    return jax.vmap(single, in_axes=(0, 0))(xs, ts)

threshold_each(jnp.array([3.0, 1.0, 2.0]),
               jnp.array([2.0, 2.0, 2.0]))  # → [1., 0., 0.]

in_axes=(0, 0) tells vmap to map over the leading axis of both arguments simultaneously.
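in_axes can also mix mapped and unmapped arguments. A sketch with a hypothetical threshold_shared variant, where one scalar threshold is shared by the whole batch: None tells vmap to pass that argument unchanged to every call.

```python
import jax
import jax.numpy as jnp
from jax import lax

def threshold_shared(xs, t):
    """Hypothetical variant: 1.0 where xs[i] > t, else 0.0, with one shared t."""
    def single(x, t):
        return lax.cond(x > t, lambda: 1.0, lambda: 0.0)
    # Map xs along axis 0; pass the same scalar t to every call.
    return jax.vmap(single, in_axes=(0, None))(xs, t)

out = threshold_shared(jnp.array([3.0, 1.0, 2.0]), 2.0)  # [1., 0., 0.]
```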

Common pitfalls

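  • Trying to vmap directly over lax.cond. Calling vmap(lax.cond)(preds, true_fun, false_fun, xs) fails: with the default in_axes=0, vmap tries to map axis 0 of every argument, including the branch functions, which are Python callables rather than arrays. Wrap the cond in a helper that closes over the branches and takes only array arguments, as single does above.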
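  • Closures versus mapped arguments. Inside the vmapped function, mapped arguments such as x and t are per-row values, so jnp.sum(x) reduces one row, not the entire batch. Arrays captured by closure are not mapped: every call sees them whole, so jnp.sum on a closed-over batch array reduces everything.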
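  • Wrong in_axes. The default in_axes=0 already maps axis 0 of every argument, so in_axes=(0, 0) is only being explicit here. Trouble comes from a wrong spec: in_axes=(0, None) passes the whole ts array to every call, so x > t is no longer a scalar and lax.cond raises a TypeError. Arguments whose mapped axes have different sizes raise a ValueError.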
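The closure and in_axes pitfalls can be reproduced in a few lines. This sketch reuses the shape of the lesson's single(x, t); the arrays and the row_vs_batch helper are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

xs2d = jnp.array([[1.0, 2.0], [3.0, 4.0]])

def row_vs_batch(row):
    # row is the mapped argument: one row of xs2d per call.
    row_sum = jnp.sum(row)
    # xs2d is captured by closure, so vmap does not map it: full batch every call.
    batch_sum = jnp.sum(xs2d)
    return row_sum, batch_sum

row_sums, batch_sums = jax.vmap(row_vs_batch)(xs2d)  # [3., 7.] and [10., 10.]

def single(x, t):
    return lax.cond(x > t, lambda: 1.0, lambda: 0.0)

xs = jnp.array([3.0, 1.0, 2.0])
ts = jnp.array([2.0, 2.0, 2.0])

# Wrong spec: None passes the whole ts array to every call, so x > t
# has shape (3,) and lax.cond rejects the non-scalar predicate.
try:
    jax.vmap(single, in_axes=(0, None))(xs, ts)
    wrong_axes_error = None
except TypeError as e:
    wrong_axes_error = e

# Mismatched sizes along the mapped axes are rejected too.
try:
    jax.vmap(single)(xs, ts[:2])
    size_mismatch_error = None
except ValueError as e:
    size_mismatch_error = e
```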

Problem

Implement vmap_classify_batch(x_batch, thresholds) that classifies each row of x_batch independently: output i is 1.0 if sum(x_batch[i]) > thresholds[i], else -1.0.

  • x_batch: 2-D JAX array of shape (N, d).
  • thresholds: 1-D JAX array of shape (N,).
  • Returns: 1-D array of shape (N,).

Hints

jax vmap lax-cond composition
