hard primitives

vmap-cond vs where: cost tradeoff

Why this matters

When you need per-element branching over a batch, JAX offers two idioms with identical numerical results but different cost profiles:

  • vmap(lax.cond): traces both branches once. With a single unbatched predicate, lax.cond executes only the selected branch, which can save work when branches are expensive. Under vmap with a per-element predicate, however, JAX converts the cond into a select, so both branches are computed for every element and the per-element short-circuit exists only in theory.
  • jnp.where: always evaluates both branches for every element, then selects. Vectorizes cleanly, with no per-element control-flow overhead.

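Understanding this tradeoff lets you choose the right primitive for the problem: lax.cond when the predicate is a single unbatched scalar and skipping the unused branch actually saves work; jnp.where for simple element-wise selection where both branches are cheap. Once the predicate is batched, expect the two idioms to cost about the same, because both end up evaluating both branches.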
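One way to see what vmap does to lax.cond is to print the traced program (the jaxpr) with and without vmap. The sketch below is illustrative only: the function name branch and the threshold 1.0 are made up for the demonstration, and the exact primitive names in the printout can vary between JAX versions, so it only checks for substrings.

```python
import jax
import jax.numpy as jnp
from jax import lax

def branch(x):
    # Hypothetical per-element branch: multiply above the threshold, else add.
    return lax.cond(x > 1.0, lambda: x * 10.0, lambda: x + 1.0)

# Trace once with a scalar input and once under vmap with a batched input.
scalar_text = str(jax.make_jaxpr(branch)(jnp.array(0.5)))
batched_text = str(jax.make_jaxpr(jax.vmap(branch))(jnp.array([0.5, 2.0, 3.0])))

# The scalar trace should keep a cond primitive; the batched trace should
# contain a select instead, meaning both branches are computed per element.
has_cond_scalar = "cond" in scalar_text
has_select_batched = "select" in batched_text
print("cond in scalar trace:   ", has_cond_scalar)
print("select in batched trace:", has_select_batched)
```

If the batched trace shows a select rather than a cond, the per-element version has the same evaluate-both-then-pick structure as jnp.where.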

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

def two_ways(xs, t):
    """xs > t → xs * 10, else xs + 1. Two implementations."""
    # Way (a): vmap(cond)
    def per_elem(x):
        return lax.cond(x > t, lambda: x * 10.0, lambda: x + 1.0)
    result_a = jax.vmap(per_elem)(xs)

    # Way (b): where
    result_b = jnp.where(xs > t, xs * 10.0, xs + 1.0)

    return result_a, result_b  # numerically identical

two_ways(jnp.array([0.5, 2.0, 3.0]), 1.0)
# → ([1.5, 20., 30.], [1.5, 20., 30.])

Common pitfalls

  • Forgetting to broadcast the condition in jnp.where. If x_batch has shape (N, d) and your per-row condition has shape (N,), you must expand it: cond[:, None]. Without this, jnp.where raises a shape error, or, when N happens to equal d, silently broadcasts the condition along the wrong axis.
  • Expecting different results. Both methods produce the same numbers; the difference is purely about compiled performance, not correctness.
  • vmap of cond on 2-D input. per_row_cond receives a 1-D row; the lambda: row * 2.0 must close over that per-row row, not the full batch.
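The first and third pitfalls can be illustrated on a tiny batch. This is a sketch under assumed inputs: the array values, the threshold 2.5, and the helper name per_row are chosen for the example and are not part of the problem statement.

```python
import jax
import jax.numpy as jnp
from jax import lax

x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.5]])  # shape (3, 2)
threshold = 2.5

# Per-row condition has shape (3,): row sums are 3.0, 7.0, 0.5.
row_cond = x_batch.sum(axis=1) > threshold

# Pitfall: shape (3,) does not broadcast against (3, 2), so this fails.
try:
    jnp.where(row_cond, x_batch * 2.0, x_batch + 1.0)
    unexpanded_failed = False
except (ValueError, TypeError):
    unexpanded_failed = True

# Fix: add a trailing axis so the condition has shape (3, 1).
via_where = jnp.where(row_cond[:, None], x_batch * 2.0, x_batch + 1.0)

def per_row(row):
    # row is one 1-D row of shape (2,); the lambdas close over it,
    # not over the full x_batch.
    return lax.cond(row.sum() > threshold, lambda: row * 2.0, lambda: row + 1.0)

via_cond = jax.vmap(per_row)(x_batch)
print(via_where)
# → [[2.  4. ]
#    [6.  8. ]
#    [1.  1.5]]
```

Here N = 3 and d = 2 differ, so the unexpanded condition fails loudly; with a square batch the same mistake would run without an error.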

Problem

Implement vmap_branch_two_ways(x_batch, threshold) that applies the branch logic row * 2 if sum(row) > threshold else row + 1 per row, via both vmap(lax.cond) and jnp.where.

  • x_batch: 2-D JAX array of shape (N, d).
  • threshold: scalar.
  • Returns: 3-D array of shape (2, N, d): [result_via_cond, result_via_where]. Both slices are numerically identical.
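For checking an attempt, a plain NumPy reference of the stated rule may help. This is not the intended solution, only a slow loop that spells out the specification; reference_two_ways is a hypothetical helper name, and it assumes N is at least 1.

```python
import numpy as np

def reference_two_ways(x_batch, threshold):
    """Loop-based reference: row * 2 if sum(row) > threshold else row + 1."""
    x = np.asarray(x_batch, dtype=np.float32)
    rows = [row * 2.0 if row.sum() > threshold else row + 1.0 for row in x]
    expected = np.stack(rows)              # shape (N, d)
    # Both slices are the same by specification, so stack the result twice.
    return np.stack([expected, expected])  # shape (2, N, d)

expected = reference_two_ways([[1.0, 2.0], [3.0, 4.0], [0.0, 0.5]], 2.5)
print(expected.shape)  # (2, 3, 2)
```

A JAX implementation can then be compared against this output with np.allclose.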

Hints

jax vmap lax-cond where performance
