medium primitives

Fused Elementwise Ops

Why this matters

XLA, the compiler behind JAX, fuses chains of elementwise operations into a single GPU/TPU kernel when the surrounding function is compiled with jax.jit. Writing ((x * 2) + 1) ** 2 as one expression produces the same fused kernel as splitting it into three separate operations. The one-liner therefore costs nothing extra and reads more clearly. Understanding fusion lets you write clean code and trust the compiler to optimise it.

import jax.numpy as jnp

x = jnp.arange(4.0)

# These are equivalent in output AND in compiled performance:
# One-liner (XLA fuses it automatically under jax.jit):
result = ((x * 2.0) + 1.0) ** 2

# Separate (traces to the same ops, so also fused under jax.jit, just more verbose):
tmp1 = x * 2.0
tmp2 = tmp1 + 1.0
result = tmp2 ** 2
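One way to check that both spellings compile identically is to compare what JAX traces. The sketch below wraps each version in a function (the names one_liner and separate are illustrative, not from the original) and compares their jaxprs with jax.make_jaxpr. Tracing records only the primitive operations, not the Python variable names, so both versions should yield the same short sequence (multiply, add, integer power) that XLA then fuses under jit.

```python
import jax
import jax.numpy as jnp

def one_liner(x):
    # Whole chain written as a single expression.
    return ((x * 2.0) + 1.0) ** 2

def separate(x):
    # Same chain with named intermediates; the names exist only in Python.
    tmp1 = x * 2.0
    tmp2 = tmp1 + 1.0
    return tmp2 ** 2

x = jnp.arange(4.0)

# Tracing records the primitive ops, so both jaxprs print identically.
jaxpr_a = jax.make_jaxpr(one_liner)(x)
jaxpr_b = jax.make_jaxpr(separate)(x)
print(jaxpr_a)
print(str(jaxpr_a) == str(jaxpr_b))  # True

# Under jit, each traced program is compiled by XLA as a whole.
fast_a = jax.jit(one_liner)
fast_b = jax.jit(separate)
print(fast_a(x))  # [ 1.  9. 25. 49.]
```

To look at fusion itself, jax.jit(one_liner).lower(x).compile().as_text() returns the optimised HLO. On typical backends the elementwise ops appear inside a single fusion computation, though the exact text varies by backend and JAX version.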

Worked mini-example

import jax.numpy as jnp

x = jnp.array([0.0, 1.0])
print(((x * 2.0) + 1.0) ** 2)  # [1. 9.]
# x=0: ((0*2)+1)^2 = 1
# x=1: ((1*2)+1)^2 = 9

Common pitfalls

  • Trying to manually fuse: XLA does it for you. Don’t reach for custom CUDA kernels for simple elementwise chains.
  • Fusion only happens under jit: without jax.jit, each operation executes eagerly as its own dispatch, materialising every intermediate array. The performance benefit of fusion appears only at JIT-compiled call sites.
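  • Integer vs float literals: JAX treats Python scalars as weakly typed, so for a float array x, x * 2 and x * 2.0 give the same dtype. With an integer array, integer literals keep the result integer while float literals promote it to a floating dtype. Prefer 2.0 and 1.0 when you want a floating-point result regardless of the input dtype.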
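The last two pitfalls can be seen in a short sketch. It assumes JAX's default 32-bit mode for the dtype names in the comments. The assertions only check integer vs floating kind, so they also hold with x64 enabled. chain_int and chain_float are hypothetical helper names used only for illustration.

```python
import jax
import jax.numpy as jnp

def chain_int(x):
    # Integer literals: the result dtype follows x.
    return ((x * 2) + 1) ** 2

def chain_float(x):
    # Float literals: an integer x is promoted to a floating dtype.
    return ((x * 2.0) + 1.0) ** 2

x_int = jnp.array([0, 1])
x_float = jnp.array([0.0, 1.0])

# With an integer input, the literal type decides the output dtype.
print(chain_int(x_int).dtype)    # int32 in default 32-bit mode
print(chain_float(x_int).dtype)  # float32 in default 32-bit mode

# With a float input, Python scalars are weakly typed, so both agree.
print(chain_int(x_float).dtype == chain_float(x_float).dtype)  # True

# Eager calls dispatch each op separately; the jitted version compiles
# the whole chain at once, which is where XLA can fuse it.
chain_jit = jax.jit(chain_float)
print(chain_jit(x_float))  # [1. 9.]
```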

Problem

Implement fused_chain(x) that computes ((x * 2) + 1) ** 2 as a single expression and returns the result.

  • x: 1-D JAX array.

Returns: 1-D array of the same shape as x.

Examples (not from the test set):

  • fused_chain(jnp.array([0.0, 1.0])) → [1., 9.]

Hints

jax fusion xla
