
jax.named_scope for Profiler Grouping

Why this matters

When you open a profiler trace (TensorBoard, Perfetto) for a large training step, you see hundreds of generically named XLA ops, and it is hard to tell which ops belong to the forward pass, the backward pass, or the optimizer update. jax.named_scope("name") fixes this: every op traced inside the scope gets "name" added to its name stack, which is recorded in the op metadata of the compiled program. The profiler can then label and group those ops, which makes its output readable.

The key insight: named_scope has zero effect on numerical output. It is purely informational — a structured annotation for the profiler.
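To see that the annotation is metadata only, the sketch below jits the same loss with and without a scope and compares the results. The function names, shapes, and PRNG seed are illustrative assumptions. Since only op names differ between the two programs, they should return identical values.

```python
import jax
import jax.numpy as jnp


def loss_plain(w, x):
    # Same computation with no scope annotation.
    return jnp.sum(jnp.dot(x, w) ** 2)


def loss_scoped(w, x):
    # Identical math wrapped in a named scope; only op metadata changes.
    with jax.named_scope("forward_pass"):
        return jnp.sum(jnp.dot(x, w) ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (4,))
x = jax.random.normal(key_x, (3, 4))

plain = jax.jit(loss_plain)(w, x)
scoped = jax.jit(loss_scoped)(w, x)

# The scope is metadata only, so both compiled programs compute the same value.
print(bool(jnp.array_equal(plain, scoped)))
```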

Worked mini-example

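import jax
import jax.numpy as jnp

@jax.jit  # under jit, scope names are recorded in the compiled ops' metadata
def training_step(params, x):
    # Forward pass: matmul followed by a squared-sum loss.
    with jax.named_scope("forward_pass"):
        logits = jnp.dot(x, params)
        loss = jnp.sum(logits ** 2)

    # Backward pass: gradient of the same loss with respect to params.
    with jax.named_scope("backward_pass"):
        grad = jax.grad(lambda p: jnp.sum(jnp.dot(x, p) ** 2))(params)

    # Optimizer: plain SGD update with learning rate 0.01.
    with jax.named_scope("optimizer"):
        params = params - 0.01 * grad

    return loss, params

In a profiler trace, the XLA ops generated inside each block carry the scope name in their op metadata (an op name roughly of the form jit(training_step)/forward_pass/dot_general), so the "forward_pass", "backward_pass", and "optimizer" work can be told apart in the timeline.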
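A minimal sketch of capturing such a trace, assuming a CPU or GPU backend and a throwaway temporary directory. jax.profiler.trace writes the profile under log_dir, which can then be opened in TensorBoard (with the profile plugin installed). The toy loss_fn and step functions, the shapes, and the learning rate are hypothetical. The step runs once before tracing so that compilation does not dominate the captured timeline.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def loss_fn(params, x):
    # Forward-pass ops are tagged "forward_pass" in the compiled program.
    with jax.named_scope("forward_pass"):
        return jnp.sum(jnp.dot(x, params) ** 2)


@jax.jit
def step(params, x):
    loss, grad = jax.value_and_grad(loss_fn)(params, x)
    # Parameter-update ops are tagged "optimizer".
    with jax.named_scope("optimizer"):
        params = params - 0.01 * grad
    return loss, params


params = jnp.ones((4,))
x = jnp.ones((8, 4))

# Warm-up call: compile before tracing so the trace shows execution only.
loss, params = step(params, x)
loss.block_until_ready()

log_dir = tempfile.mkdtemp()
with jax.profiler.trace(log_dir):
    for _ in range(3):
        loss, params = step(params, x)
    # Wait for device work to finish before the trace is stopped.
    loss.block_until_ready()

# The profile files are written into a subdirectory of log_dir.
written = [os.path.join(root, name)
           for root, _, names in os.walk(log_dir) for name in names]
print(written)
```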

Common pitfalls

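  • named_scope only takes effect while it is active: use it as a context manager (with jax.named_scope("name"):) or as a function decorator (@jax.named_scope("name")). Calling jax.named_scope("name") as a bare statement does nothing.
  • It does not change numerical output. If your results differ after adding a scope, the bug is elsewhere.
  • Scopes are nestable: a named_scope block inside another one produces a slash-separated path such as outer/inner, which gives fine-grained labels.
  • The name appears exactly as written, so typos show up in the trace. The name must be a string.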
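These pitfalls can be checked without a profiler by reading the op metadata of the compiled program. This sketch uses the decorator form, nests two with-blocks inside it, and includes a bare call that is never entered. The layer function and shapes are made up for illustration, and it assumes the optimized HLO text from .lower(...).compile().as_text() includes op_name metadata, as it does on the CPU backend in recent JAX releases as far as we know. It should print True, True, False, followed by the layer output.

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.named_scope("layer")  # decorator form: scopes the whole function body
def layer(w, x):
    # Pitfall: a bare call creates the context manager but never enters it,
    # so it labels nothing.
    jax.named_scope("ignored")
    with jax.named_scope("dense"):  # nested inside "layer" -> "layer/dense"
        h = jnp.dot(x, w)
    with jax.named_scope("activation"):  # -> "layer/activation"
        return jnp.maximum(h, 0.0)


w = jnp.ones((4, 2))
x = jnp.ones((3, 4))

# Lower and compile ahead of time, then get the optimized HLO as text.
hlo_text = layer.lower(w, x).compile().as_text()

print("layer/dense" in hlo_text)       # nested scope path in op metadata
print("layer/activation" in hlo_text)
print("ignored" in hlo_text)           # the bare call left no label
print(layer(w, x))
```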

Problem

Implement named_scope_compute(x) that:

  1. Opens a jax.named_scope("forward_pass") context.
  2. Inside it, computes and returns jnp.sum(x ** 2).

Input:
  • x: 1-D JAX array.

Returns: a scalar, sum(x ** 2).
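One possible solution, as a minimal sketch: the scope wraps the computation and the value is returned from inside the block, which exits the context normally. Because the scope only adds metadata, wrapping the function in jax.jit gives the same result.

```python
import jax
import jax.numpy as jnp


def named_scope_compute(x):
    # Every op traced inside this block is tagged with "forward_pass".
    with jax.named_scope("forward_pass"):
        return jnp.sum(x ** 2)


result = named_scope_compute(jnp.array([1.0, 2.0, 3.0]))
print(result)  # 14.0
```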

Hints

jax profiler named-scope
