medium
primitives
Fused Elementwise Ops
Why this matters
XLA, the compiler behind JAX, automatically fuses chains of
elementwise operations into a single GPU/TPU kernel when the code runs
under jax.jit. Writing ((x * 2) + 1) ** 2 as one expression compiles to
the same fused kernel as splitting it into three separate statements, and
the one-liner is more readable. Understanding fusion helps you write clean
code and trust the compiler to optimise it.
import jax.numpy as jnp
x = jnp.arange(4.0)  # any float array works here
# These are equivalent in output AND, when wrapped in jax.jit, in compiled performance:
# One-liner (XLA fuses automatically under jit):
result = ((x * 2.0) + 1.0) ** 2
# Separate (also fused under jit, but verbose):
tmp1 = x * 2.0
tmp2 = tmp1 + 1.0
result = tmp2 ** 2
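As a rough check of the equivalence claimed above, the sketch below jits both spellings and compares their outputs, then prints the optimised HLO so you can look for a fusion yourself. The names one_liner and split_up are made up for illustration, and the HLO text depends on the backend and JAX version, so no particular compiler output is assumed.

```python
import jax
import jax.numpy as jnp


def one_liner(x):
    # Whole chain written as a single expression.
    return ((x * 2.0) + 1.0) ** 2


def split_up(x):
    # Same chain spelled out with named intermediates.
    tmp1 = x * 2.0
    tmp2 = tmp1 + 1.0
    return tmp2 ** 2


x = jnp.arange(4.0)

# Compile and run both spellings under jit.
out_a = jax.jit(one_liner)(x)
out_b = jax.jit(split_up)(x)
same_values = bool(jnp.allclose(out_a, out_b))
print(same_values)

# Optimised HLO for the one-liner; contents vary by backend and version.
compiled_text = jax.jit(one_liner).lower(x).compile().as_text()
print(compiled_text)
```

Matching outputs alone do not prove that fusion happened. The printed HLO is where you would confirm it on your own backend.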
Worked mini-example
import jax.numpy as jnp
x = jnp.array([0.0, 1.0])
print(((x * 2.0) + 1.0) ** 2) # [1. 9.]
# x=0: ((0*2)+1)^2 = 1
# x=1: ((1*2)+1)^2 = 9
Common pitfalls
- Trying to manually fuse: XLA does it for you. Don’t reach for custom CUDA kernels for simple elementwise chains.
- Fusion only fires under jit: without jax.jit, operations execute eagerly one at a time. The performance benefit of fusion appears at JIT-compiled call sites.
- Integer vs float literals: prefer 2.0 and 1.0 over 2 and 1 when you need a floating-point result. With an integer input array, the literals 2 and 1 keep the result integer, while 2.0 and 1.0 promote it to float.
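The literal pitfall is easiest to see with an integer input. The following is a minimal sketch, assuming JAX's default dtype settings; x_int, int_result and float_result are illustrative names only.

```python
import jax.numpy as jnp

x_int = jnp.array([0, 1])  # integer input array

# Integer literals: the result stays an integer array.
int_result = ((x_int * 2) + 1) ** 2

# Float literals: the result is promoted to a floating-point array.
float_result = ((x_int * 2.0) + 1.0) ** 2

print(int_result.dtype, int_result)
print(float_result.dtype, float_result)
```

With a float input such as jnp.array([0.0, 1.0]), both spellings give a float result, so the difference only shows up when integer data reaches the function.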
Problem
Implement fused_chain(x) that computes ((x * 2) + 1) ** 2 as a single
expression and returns the result.
- x: 1-D JAX array.
Returns: 1-D array of the same shape as x.
Examples (not from the test set):
- fused_chain(jnp.array([0.0, 1.0])) → [1., 9.]
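One possible shape for an answer, offered as a sketch rather than the reference solution (which is not shown on this page). The jax.jit decorator and the float literals are choices made here, not requirements stated in the problem.

```python
import jax
import jax.numpy as jnp


@jax.jit  # optional for correctness; jit is where XLA fusion applies
def fused_chain(x):
    # The whole chain as a single expression, using float literals.
    return ((x * 2.0) + 1.0) ** 2


result = fused_chain(jnp.array([0.0, 1.0]))
print(result)
```

Because every step is elementwise, the output has the same shape as x.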