jax.named_scope for Profiler Grouping
Why this matters
When you open a profiler trace (TensorBoard, Perfetto) for a large
training step, you see hundreds of unlabeled XLA ops — it’s hard to
tell which ops belong to the forward pass, backward pass, or optimizer
update. jax.named_scope("name") fixes this by attaching the name to every op traced inside its with block, so the profiler can group those ops under a labeled span in the trace timeline.
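A minimal sketch, assuming a recent JAX release, of where the label ends up before any profiler is involved. It dumps the compiled HLO text, in which each op carries an op_name metadata string that should include the enclosing scope. The function name `scoped` is hypothetical, and the exact op_name format (for example whether a `jit(main)` prefix appears) varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def scoped(x):
    # Every op traced inside this block gets "forward_pass" in its name stack.
    with jax.named_scope("forward_pass"):
        return jnp.sum(x ** 2)

x = jnp.arange(4.0)
# Lower and compile for the current backend, then dump the HLO as text.
hlo_text = jax.jit(scoped).lower(x).compile().as_text()

# Print only the lines whose op metadata mentions the scope label.
for line in hlo_text.splitlines():
    if "forward_pass" in line:
        print(line.strip())
```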
The key insight: named_scope has zero effect on numerical output.
It is purely informational — a structured annotation for the profiler.
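A quick way to convince yourself of this is to compare a function with and without the scope. The sketch below uses two hypothetical functions, `plain` and `scoped`, that differ only in the named_scope wrapper; on the same backend their jitted outputs and gradients should be identical, since the scope only changes op metadata.

```python
import jax
import jax.numpy as jnp

def plain(x, w):
    return jnp.tanh(x @ w).sum()

def scoped(x, w):
    # Same computation, wrapped in a named scope purely for labeling.
    with jax.named_scope("forward_pass"):
        return jnp.tanh(x @ w).sum()

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 4))
w = jax.random.normal(key_w, (4, 3))

# Compare jitted outputs and gradients with respect to w, bit for bit.
same_value = bool(jnp.array_equal(jax.jit(plain)(x, w), jax.jit(scoped)(x, w)))
same_grad = bool(jnp.array_equal(jax.grad(plain, argnums=1)(x, w),
                                  jax.grad(scoped, argnums=1)(x, w)))
print(same_value, same_grad)
```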
Worked mini-example
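import jax
import jax.numpy as jnp

def training_step(params, x):
    # Ops traced here are labeled "forward_pass" in the trace.
    with jax.named_scope("forward_pass"):
        logits = jnp.dot(x, params)
        loss = jnp.sum(logits ** 2)
    # Gradient computation and parameter update are labeled "optimizer".
    with jax.named_scope("optimizer"):
        grad = jax.grad(lambda p: jnp.sum(jnp.dot(x, p) ** 2))(params)
        params = params - 0.01 * grad  # plain SGD step, learning rate 0.01
    return loss, params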
In a profiler trace, XLA ops inside each scope appear under a labeled span — “forward_pass” and “optimizer” are visible in the timeline.
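A minimal sketch of capturing such a trace, assuming a JAX install with its built-in profiler. It jits the step above, warms it up so compilation stays out of the timeline, and records a few calls with jax.profiler.trace. The temporary log directory is an arbitrary choice; point TensorBoard (with the profile plugin) at it, or pass create_perfetto_trace=True to also write a file Perfetto can open. How the scope labels are rendered depends on the viewer and JAX version.

```python
import tempfile

import jax
import jax.numpy as jnp

def training_step(params, x):
    with jax.named_scope("forward_pass"):
        logits = jnp.dot(x, params)
        loss = jnp.sum(logits ** 2)
    with jax.named_scope("optimizer"):
        grad = jax.grad(lambda p: jnp.sum(jnp.dot(x, p) ** 2))(params)
        params = params - 0.01 * grad
    return loss, params

step = jax.jit(training_step)
params = jnp.ones((4,))
x = jnp.ones((8, 4))

# Warm up outside the trace so compilation does not dominate the timeline.
loss, params = step(params, x)
loss.block_until_ready()

log_dir = tempfile.mkdtemp()  # any writable directory works
with jax.profiler.trace(log_dir):
    for _ in range(3):
        loss, params = step(params, x)
    # Wait for asynchronous dispatch to finish before the trace closes.
    loss.block_until_ready()

print("trace written under", log_dir)
```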
Common pitfalls
- named_scope is a context manager: use it in a with statement. A bare call such as jax.named_scope("forward_pass") on its own line creates the context manager and discards it, so nothing gets labeled.
- It does not change numerical output; if your results differ, the bug is elsewhere.
- Scopes are nestable: you can place named_scope blocks inside other named_scope blocks for fine-grained labeling.
- The name appears exactly as written, so typos show up in the trace.
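The sketch below illustrates two of these pitfalls with hypothetical functions. `update` nests scopes so the gradient and the parameter update get separate labels (op names should contain both levels, e.g. optimizer/grad, though the exact format varies by JAX version). `update_unlabeled` makes the bare-call mistake, so its ops carry no "optimizer" label, yet both return the same values.

```python
import jax
import jax.numpy as jnp

def update(params, x, lr=0.01):
    with jax.named_scope("optimizer"):
        # Nested scope: these ops carry both "optimizer" and "grad".
        with jax.named_scope("grad"):
            grad = jax.grad(lambda p: jnp.sum(jnp.dot(x, p) ** 2))(params)
        # These ops carry "optimizer" and "apply".
        with jax.named_scope("apply"):
            new_params = params - lr * grad
    return new_params

def update_unlabeled(params, x, lr=0.01):
    # Pitfall: this creates a context manager and throws it away,
    # so none of the ops below are labeled.
    jax.named_scope("optimizer")
    grad = jax.grad(lambda p: jnp.sum(jnp.dot(x, p) ** 2))(params)
    return params - lr * grad

params = jnp.ones((4,))
x = jnp.ones((8, 4))
# Labels differ, numbers do not.
print(jnp.array_equal(jax.jit(update)(params, x),
                      jax.jit(update_unlabeled)(params, x)))
```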
Problem
Implement named_scope_compute(x) that:
- Opens a jax.named_scope("forward_pass") context.
- Inside it, computes and returns jnp.sum(x ** 2).
Input: x, a 1-D JAX array.
Returns: scalar, sum(x^2).
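One possible solution sketch that follows the specification above. Returning from inside the with block is fine: the scope only needs to be open while the ops are traced, and the function works the same eagerly or under jax.jit.

```python
import jax
import jax.numpy as jnp

def named_scope_compute(x):
    # Label the computation; the scope does not change the result.
    with jax.named_scope("forward_pass"):
        return jnp.sum(x ** 2)

print(named_scope_compute(jnp.array([1.0, 2.0, 3.0])))  # 14.0
```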