
Multi-Branch Dispatch via lax.switch

Why this matters

lax.switch(index, branches, *operands) is the N-way generalisation of lax.cond. Where cond picks between two branches, switch selects one branch from an arbitrary list using an integer index, making it the functional, jit-safe analogue of a match/switch statement.
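To make the "N-way generalisation" concrete, here is a minimal sketch that makes the same choice with both primitives. It relies on the two-branch case lining up with lax.cond once the boolean is converted to 0/1, so the false branch sits at index 0 and the true branch at index 1.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-3.0, 2.0])
pred = jnp.array(True)

# lax.cond(pred, true_fn, false_fn, *operands)
via_cond = lax.cond(pred, jnp.abs, lambda v: -v, x)

# The same choice with lax.switch: index 0 is the false branch, index 1 the true branch
via_switch = lax.switch(pred.astype(jnp.int32), [lambda v: -v, jnp.abs], x)

print(via_cond, via_switch)
```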

This is the right primitive whenever the number of alternatives is known at trace time but the choice is data-dependent (e.g. selecting a math operation based on a user-supplied integer). Under jit, a Python if/elif/else chain that tests a traced JAX integer raises ConcretizationTypeError, because Python needs a concrete boolean while the function is being traced. lax.switch avoids this by tracing every branch and deferring the choice to runtime.
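A minimal sketch of the failure and the fix side by side. The function names python_dispatch and switch_dispatch are hypothetical. The except clause catches jax.errors.ConcretizationTypeError; newer JAX releases may raise a more specific subclass of it, which the same clause still matches.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def python_dispatch(x, idx):
    # Python control flow on a traced integer: `if` needs a concrete bool
    if idx == 0:
        return -x
    elif idx == 1:
        return x ** 2
    else:
        return jnp.abs(x)

@jax.jit
def switch_dispatch(x, idx):
    # lax.switch traces all three branches and selects one at runtime
    return lax.switch(idx, [lambda v: -v, lambda v: v ** 2, jnp.abs], x)

x = jnp.array([-3.0, 2.0])

try:
    python_dispatch(x, 1)
    python_failed = False
except jax.errors.ConcretizationTypeError:
    python_failed = True

result = switch_dispatch(x, 1)
print(python_failed, result)
```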

Worked mini-example

from jax import lax
import jax.numpy as jnp

def apply_op(x, idx):
    # idx=0 → negate, idx=1 → square, idx=2 → abs
    branches = [
        lambda x: -x,
        lambda x: x ** 2,
        lambda x: jnp.abs(x),
    ]
    return lax.switch(jnp.int32(idx), branches, x)

apply_op(jnp.array([-3.0, 2.0]), 1)  # → [9., 4.]

lax.switch passes *operands as arguments to the selected branch. Every branch must accept the same operands and return an output of identical shape and dtype.
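Because operands are forwarded positionally, branches can take more than one argument as long as every branch has the same signature. A minimal sketch with a hypothetical combine function that picks an elementwise binary operation; it is jitted so the index is a traced value rather than a Python int, and all three indices reuse one compiled function because the argument shapes and dtypes do not change.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def combine(a, b, idx):
    # Every branch takes the same two operands and returns an array with
    # the shape and dtype of a and b, as lax.switch requires.
    branches = [
        lambda a, b: a + b,              # idx=0: sum
        lambda a, b: a * b,              # idx=1: product
        lambda a, b: jnp.maximum(a, b),  # idx=2: elementwise max
    ]
    return lax.switch(idx, branches, a, b)

a = jnp.array([1.0, 5.0])
b = jnp.array([4.0, 2.0])
results = [combine(a, b, i) for i in range(3)]
```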

Common pitfalls

  • Float index. lax.switch requires an integer index. If op_index arrives as a float scalar (common when passed through a traced pipeline), cast it: jnp.int32(op_index).
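  • Out-of-range index. lax.switch clamps the index into the range [0, len(branches) - 1] rather than raising, so with three branches an index of 7 silently runs the last branch and -1 runs the first. Clipping the index yourself only reproduces this behaviour; if invalid input must be detected, validate it or route it to an explicit fallback branch.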
  • Mismatched branch outputs. All branches must return the same shape and dtype. Returning jnp.sin(x) (float32) from one branch and an integer array from another raises a TypeError at trace time.
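The three pitfalls can be checked directly. This is a sketch under stated assumptions: dispatch_checked and its NaN-returning fallback branch are hypothetical choices for illustration (any sentinel output with matching shape and dtype would do), and the exact TypeError message for mismatched outputs varies between JAX versions.

```python
import jax.numpy as jnp
from jax import lax

branches = [lambda v: -v, lambda v: v ** 2, jnp.abs]
x = jnp.array([-3.0, 2.0])

# Out-of-range indices are clamped, not rejected
hi_clamped = lax.switch(7, branches, x)   # behaves like index 2 (abs)
lo_clamped = lax.switch(-1, branches, x)  # behaves like index 0 (negate)

def dispatch_checked(x, idx):
    # Float index: cast to int32 (the cast truncates toward zero)
    idx = jnp.asarray(idx).astype(jnp.int32)
    n = len(branches)
    valid = (idx >= 0) & (idx < n)

    # Hypothetical fallback branch: NaNs with the same shape and dtype
    def nan_branch(v):
        return jnp.full_like(v, jnp.nan)

    # Invalid indices are routed to the extra branch at position n
    return lax.switch(jnp.where(valid, idx, n), branches + [nan_branch], x)

checked_ok = dispatch_checked(x, 1.0)  # float index cast to 1, so square
checked_bad = dispatch_checked(x, 7)   # flagged with NaNs instead of clamped

# Mismatched branch outputs (float32 vs int32) fail when the branches are traced
try:
    lax.switch(0, [jnp.sin, lambda v: v.astype(jnp.int32)], x)
    mismatch_raised = False
except TypeError:
    mismatch_raised = True
```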

Problem

Implement switch_op(x, op_index) that applies one of [jnp.sin, jnp.cos, jnp.exp] to x, selected by op_index.

  • x: 1-D JAX array.
  • op_index: scalar (cast to jnp.int32 internally); 0 → sin, 1 → cos, 2 → exp.
  • Returns: 1-D array, same shape as x.
