medium primitives

profiler.StepTraceAnnotation

Why this matters

When profiling a training loop, it’s not enough to see individual op timings; you also need to see step boundaries. jax.profiler.StepTraceAnnotation marks a training step in the profiler timeline, so tools like TensorBoard can clearly delineate step N from step N+1. This makes it easy to compare per-step latency and identify slow steps.

Like jax.named_scope, the annotation has zero effect on numerical output.
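
One way to check this claim is to run the same computation with and without the annotation and compare the results. The sketch below is only illustrative; the helper names sum_squares and sum_squares_annotated are made up for this example and are not part of JAX.

```python
import jax.numpy as jnp
from jax import profiler as jpr

def sum_squares(x):
    # Plain computation, no profiler metadata.
    return jnp.sum(x ** 2)

def sum_squares_annotated(x, step):
    # Same math, wrapped in a step annotation. The annotation only adds
    # profiler metadata; it does not alter the computed value.
    with jpr.StepTraceAnnotation("train", step_num=int(step)):
        return jnp.sum(x ** 2)

x = jnp.arange(4.0)  # [0., 1., 2., 3.]
plain = sum_squares(x)
annotated = sum_squares_annotated(x, step=0)
same = bool(plain == annotated)  # the two results are identical
```

No profiler session needs to be active for this to run; outside a trace capture the annotation has nothing to record.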

Worked mini-example

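import jax
import jax.numpy as jnp
from jax import profiler as jpr

def loss_fn(params, x):
    return jnp.sum((params * x) ** 2)

def train(params, x, step):
    # Ops run inside this context are attributed to step `step` in the trace.
    with jpr.StepTraceAnnotation("train", step_num=int(step)):
        # Compute the loss and its gradient together instead of
        # writing the loss expression twice.
        loss, grad = jax.value_and_grad(loss_fn)(params, x)
        params = params - 0.01 * grad  # plain SGD update
    return loss, params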

In TensorBoard, each StepTraceAnnotation block appears as a labeled step boundary, making per-step comparisons straightforward.
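
To see more than one step in the timeline, the annotated step has to run in a loop while a trace is being captured. The following is a minimal sketch of that setup, not a definitive recipe: run_steps, profile_run, and the log directory argument are hypothetical names chosen for this example, and the call to block_until_ready is a design choice made here so that each step's device work finishes before its annotation closes (JAX dispatches work asynchronously).

```python
import jax
import jax.numpy as jnp
from jax import profiler as jpr

def loss_fn(params, x):
    return jnp.sum((params * x) ** 2)

@jax.jit
def update(params, x):
    # Compiled step body: loss and gradient together, then an SGD update.
    loss, grad = jax.value_and_grad(loss_fn)(params, x)
    return loss, params - 0.01 * grad

def run_steps(params, x, num_steps):
    losses = []
    for step in range(num_steps):
        # The annotation is opened in Python, outside the jitted function,
        # so `step` is a concrete int here.
        with jpr.StepTraceAnnotation("train", step_num=step):
            loss, params = update(params, x)
            # Wait for the device so this step's work ends inside the block.
            loss.block_until_ready()
        losses.append(float(loss))
    return losses, params

def profile_run(log_dir, num_steps=5):
    # Capture a trace of the whole loop; log_dir is any writable directory.
    with jpr.trace(log_dir):
        return run_steps(jnp.ones(3), jnp.arange(3.0), num_steps)

# Running the loop without a capture works too; the annotations are then idle.
losses, final_params = run_steps(jnp.ones(3), jnp.arange(3.0), num_steps=3)
```

Calling profile_run with a directory of your choice writes a trace there that a profiler viewer such as TensorBoard can load.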

Common pitfalls

  • step_num must be a Python int; passing a JAX tracer or NumPy scalar will raise. Always cast with int(step), and do so outside jit-compiled code, where step is a concrete value.
  • Does not affect numerical output: return values are unchanged.
  • Annotations don’t persist beyond the context: the block covers only the ops executed inside the with statement.
  • Used in training loops, not individual ops: wrap your whole step body, not just one operation.
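
The casting pitfall tends to appear when the step counter does not come from a Python range, for example a NumPy counter or a float restored from a checkpoint. The sketch below shows one way to normalize such values before building the annotation; to_step_num, annotated_mean, and cast_inside_jit are hypothetical helpers written for this illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import profiler as jpr

def to_step_num(step):
    # int() accepts Python numbers, NumPy scalars, and concrete 0-d JAX
    # arrays. Floats are truncated toward zero.
    return int(step)

def annotated_mean(x, step):
    with jpr.StepTraceAnnotation("train", step_num=to_step_num(step)):
        return jnp.mean(x)

x = jnp.array([1.0, 2.0, 3.0])
candidates = (7, 7.0, np.int64(7), jnp.asarray(7))
step_nums = [to_step_num(s) for s in candidates]
results = [float(annotated_mean(x, s)) for s in candidates]

@jax.jit
def cast_inside_jit(step):
    # Under jit, `step` is a tracer with no concrete value, so int() fails.
    return int(step)

try:
    cast_inside_jit(7)
    raised = False
except TypeError:
    # JAX's concretization errors are TypeError subclasses.
    raised = True
```

The last part illustrates why the cast belongs in the Python loop rather than inside a jitted function.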

Problem

Implement annotated_step(x, step) that:

  1. Opens a jpr.StepTraceAnnotation("train", step_num=int(step)) context.
  2. Inside it, computes and returns jnp.sum(x).
  • x: 1-D JAX array.
  • step: a float (cast to int before passing as step_num).

Returns: scalar, sum(x).

Hints

jax profiler step-trace
