Multi-Branch Dispatch via lax.switch
Why this matters
lax.switch(index, branches, *operands) is the N-way generalisation of
lax.cond. Where cond picks between two branches, switch selects one
branch from an arbitrary list based on an integer index: the functional,
jit-safe analogue of a match / switch statement.
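To make the "N-way cond" relationship concrete, here is a minimal sketch. It assumes two toy branch functions, true_fn and false_fn, and two wrappers, via_cond and via_switch; all four are hypothetical names chosen for illustration. A two-branch lax.switch indexed by the predicate cast to int32 picks the same branch as lax.cond. Note the ordering: index 0 is the false branch and index 1 the true branch.

```python
import jax.numpy as jnp
from jax import lax

def true_fn(x):
    # Hypothetical branch taken when the predicate is True.
    return x + 1.0

def false_fn(x):
    # Hypothetical branch taken when the predicate is False.
    return x - 1.0

def via_cond(pred, x):
    # lax.cond(pred, true_fun, false_fun, *operands)
    return lax.cond(pred, true_fn, false_fn, x)

def via_switch(pred, x):
    # Index 0 selects false_fn and index 1 selects true_fn, so casting the
    # boolean predicate to int32 reproduces lax.cond's choice.
    return lax.switch(jnp.int32(pred), [false_fn, true_fn], x)

x = jnp.array([1.0, 2.0])
print(via_cond(True, x), via_switch(True, x))    # [2. 3.] [2. 3.]
print(via_cond(False, x), via_switch(False, x))  # [0. 1.] [0. 1.]
```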
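This is the right primitive whenever the number of alternatives is known at
trace time but the choice is data-dependent (e.g. selecting a math
operation based on a user-supplied integer). Under jit, a Python
if/elif/else chain that tests a JAX integer raises
ConcretizationTypeError (in recent JAX releases, its subclass
TracerBoolConversionError), because a traced value has no concrete truth
value. lax.switch avoids this: it compiles every branch into the program
and selects one at run time.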
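A minimal sketch of the failure mode and the fix, reusing the negate/square/abs branches from the worked example; python_dispatch and switch_dispatch are hypothetical names. The first function raises because jit passes idx in as an abstract tracer; the second compiles and runs.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def python_dispatch(x, idx):
    # A Python if/elif chain needs a concrete bool, but under jit idx is
    # a tracer, so using `idx == 0` as a condition raises.
    if idx == 0:
        return -x
    elif idx == 1:
        return x ** 2
    else:
        return jnp.abs(x)

@jax.jit
def switch_dispatch(x, idx):
    # lax.switch stages the branch choice into the compiled program.
    return lax.switch(idx, [lambda v: -v, lambda v: v ** 2, jnp.abs], x)

x = jnp.array([-3.0, 2.0])
python_failed = False
try:
    python_dispatch(x, 1)
except jax.errors.ConcretizationTypeError as err:
    python_failed = True
    print("Python if/elif under jit raised", type(err).__name__)

print(switch_dispatch(x, 1))  # [9. 4.]
```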
Worked mini-example
from jax import lax
import jax.numpy as jnp

def apply_op(x, idx):
    # idx=0 → negate, idx=1 → square, idx=2 → abs
    branches = [
        lambda x: -x,
        lambda x: x ** 2,
        lambda x: jnp.abs(x),
    ]
    return lax.switch(jnp.int32(idx), branches, x)

apply_op(jnp.array([-3.0, 2.0]), 1)  # → [9., 4.]
lax.switch passes *operands as arguments to the selected branch. Every
branch must accept the same operands and return an output of identical
shape and dtype.
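Because *operands are forwarded unchanged, a branch can take more than one argument as long as every branch has the same signature. A minimal sketch with a hypothetical combine helper and two array operands; the add/multiply/max branches are illustrative choices, not part of the original example.

```python
import jax.numpy as jnp
from jax import lax

def combine(a, b, idx):
    # Every branch accepts the same two operands (u, v) and returns a
    # float32 array shaped like u, as lax.switch requires.
    branches = [
        lambda u, v: u + v,              # idx=0: add
        lambda u, v: u * v,              # idx=1: multiply
        lambda u, v: jnp.maximum(u, v),  # idx=2: elementwise max
    ]
    return lax.switch(jnp.int32(idx), branches, a, b)

a = jnp.array([1.0, 4.0])
b = jnp.array([3.0, 2.0])
print(combine(a, b, 0))  # [4. 6.]
print(combine(a, b, 1))  # [3. 8.]
print(combine(a, b, 2))  # [3. 4.]
```

If one branch returned, say, a scalar sum instead of an elementwise result, the shapes would no longer match and lax.switch would reject the branch list.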
Common pitfalls
- Float index. lax.switch requires an integer index. If op_index arrives as a float scalar (common when passed through a traced pipeline), cast it: jnp.int32(op_index).
- Out-of-range index. lax.switch clamps out-of-range indices to the nearest valid branch rather than raising (a negative index runs branch 0, a too-large index runs the last branch), so it silently executes the wrong branch. If invalid input is possible, validate the index before it reaches lax.switch instead of relying on the implicit clamp.
- Mismatched branch outputs. All branches must return the same shape and dtype. Returning jnp.sin(x) (float32) from one branch and an int from another raises a TypeError at trace time.
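The second and third pitfalls can be observed directly. This sketch assumes the three branches from the worked example; the indices -1 and 7 are arbitrary out-of-range values chosen for illustration, and bad_branches is a hypothetical, deliberately broken branch list.

```python
import jax.numpy as jnp
from jax import lax

branches = [
    lambda x: -x,          # branch 0: negate
    lambda x: x ** 2,      # branch 1: square
    lambda x: jnp.abs(x),  # branch 2: abs
]
x = jnp.array([-3.0, 2.0])

# Out-of-range indices are clamped, not rejected:
# -1 runs branch 0 and 7 runs branch 2.
low = lax.switch(jnp.int32(-1), branches, x)   # branch 0: [3., -2.]
high = lax.switch(jnp.int32(7), branches, x)   # branch 2: [3., 2.]

# Mismatched outputs (float32 vector vs int32 scalar): both branches are
# traced before anything runs, so this fails even though index 0 is valid.
bad_branches = [lambda x: jnp.sin(x), lambda x: jnp.int32(1)]
mismatch_error = None
try:
    lax.switch(jnp.int32(0), bad_branches, x)
except TypeError as err:
    mismatch_error = err
print(type(mismatch_error).__name__)
```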
Problem
Implement switch_op(x, op_index) that applies one of
[jnp.sin, jnp.cos, jnp.exp] to x, selected by op_index.
- x: 1-D JAX array.
- op_index: scalar (cast to jnp.int32 internally); 0 → sin, 1 → cos, 2 → exp.
- Returns: 1-D array, same shape as x.
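One possible sketch of switch_op, not the site's reference solution: it follows the spec above literally (cast op_index to jnp.int32, then dispatch), and wraps the function in jax.jit to show that a traced, even float-typed, index works. As with any lax.switch call, an out-of-range op_index is clamped rather than rejected; this sketch does not validate it.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def switch_op(x, op_index):
    # Cast to int32 as the spec asks; a float index such as 1.0 becomes 1.
    idx = jnp.int32(op_index)
    # Index 0 -> sin, 1 -> cos, 2 -> exp; each preserves x's shape and dtype.
    return lax.switch(idx, [jnp.sin, jnp.cos, jnp.exp], x)

x = jnp.linspace(0.0, 1.0, 4)
print(switch_op(x, 0))    # elementwise sin(x)
print(switch_op(x, 1.0))  # float index, cast to 1: elementwise cos(x)
```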