easy primitives

@profiler.annotate_function

Why this matters

@jax.profiler.annotate_function is a decorator that wraps a function in a profiler trace annotation. It is equivalent to running the function body inside with jax.profiler.TraceAnnotation(name):, where name defaults to the function's name. It is less verbose than a with block for simple helper functions, and it makes profiler traces label each call with the function's own name. The event is recorded only when the call happens while a profiler session is active.

Like all profiler annotations, it has zero effect on numerical output.
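To check the claim that the decorator leaves results unchanged, the sketch below applies jax.profiler.annotate_function by hand to a plain function and compares the outputs. The name sum_of_squares is an illustrative choice. No profiler session is started here, so no trace is written; the wrapper simply calls through.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Plain function with no profiler annotation.
    return jnp.sum(x ** 2)


# Applying the decorator by hand; @jax.profiler.annotate_function does the same.
annotated_sum_of_squares = jax.profiler.annotate_function(sum_of_squares)

x = jnp.arange(4.0)  # [0., 1., 2., 3.]
plain = sum_of_squares(x)
annotated = annotated_sum_of_squares(x)

# The wrapper only adds a trace event around the call; the math is untouched.
print(plain, annotated)  # 14.0 14.0

# The wrapper keeps the original function's metadata, such as __name__.
print(annotated_sum_of_squares.__name__)  # sum_of_squares
```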

Worked mini-example

import jax
import jax.numpy as jnp

@jax.profiler.annotate_function
def compute_loss(params, x):
    return jnp.sum((params * x) ** 2)

params = jnp.ones(3)
x = jnp.array([1.0, 2.0, 3.0])

# While a profiler session is active, this call appears as a labeled "compute_loss" span.
result = compute_loss(params, x)

This is roughly equivalent to:

def compute_loss(params, x):
    with jax.profiler.TraceAnnotation("compute_loss"):
        return jnp.sum((params * x) ** 2)
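The span only shows up if a profiler session is running when the function is called. A minimal sketch, assuming a writable temporary directory for the trace output: jax.profiler.trace starts and stops the session around the block, and block_until_ready() keeps the asynchronously dispatched computation inside the trace window. The log directory can then be opened in a trace viewer such as TensorBoard's profile plugin.

```python
import tempfile

import jax
import jax.numpy as jnp


@jax.profiler.annotate_function
def compute_loss(params, x):
    return jnp.sum((params * x) ** 2)


params = jnp.ones(3)
x = jnp.array([1.0, 2.0, 3.0])

# Throwaway output directory for the trace files (an illustrative choice).
log_dir = tempfile.mkdtemp()

# Trace events, including the "compute_loss" span, are recorded only inside this block.
with jax.profiler.trace(log_dir):
    result = compute_loss(params, x)
    # Wait for the result so the computation finishes before the trace stops.
    result.block_until_ready()

print(result)  # 14.0
```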

Common pitfalls

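  • Inside jax.jit, prefer named_scope. Under jit the Python body runs only while tracing, so annotate_function records its event at trace time rather than on every execution of the compiled function. jax.named_scope names the traced operations themselves, so the label is carried into the compiled program. If you only need a different label, pass name=, for example functools.partial(jax.profiler.annotate_function, name="loss").
  • Closures and nested functions: decorating them works, but depending on the JAX version the default label may be the qualified name (such as outer.<locals>.inner), which can be long. Pass name= for a shorter label, or use a with block.
  • Methods on a class: decorating them also works. Depending on the JAX version, the default label may be the bare method name or the qualified name (such as Model.forward); pass name= if you need a specific label.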
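The sketch below shows the two alternatives from the pitfalls side by side: a custom label passed through name= with functools.partial, and jax.named_scope inside a jitted function. The labels "loss_forward" and "loss_body" are hypothetical. Both functions compute the same value, since neither annotation changes the math.

```python
import functools

import jax
import jax.numpy as jnp


# Override the default label with the `name` keyword argument.
@functools.partial(jax.profiler.annotate_function, name="loss_forward")
def compute_loss(params, x):
    return jnp.sum((params * x) ** 2)


@jax.jit
def jitted_loss(params, x):
    # named_scope attaches "loss_body" to the operations traced here, so the
    # label ends up in the compiled program instead of a trace-time host event.
    with jax.named_scope("loss_body"):
        return jnp.sum((params * x) ** 2)


params = jnp.ones(3)
x = jnp.array([1.0, 2.0, 3.0])

eager = compute_loss(params, x)
jitted = jitted_loss(params, x)
print(eager, jitted)  # 14.0 14.0
```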

Problem

Implement annotated_call(x) that:

  1. Defines an inner function inner(x) decorated with @jax.profiler.annotate_function.
  2. Inside inner, computes and returns jnp.sum(x ** 2).
  3. Calls inner(x) and returns the result.
Input: x, a 1-D JAX array.

Returns: scalar — sum(x^2).
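One possible solution, sketched directly from the three steps above; it is not necessarily the reference solution.

```python
import jax
import jax.numpy as jnp


def annotated_call(x):
    # Step 1: nested helper wrapped in a profiler trace annotation.
    @jax.profiler.annotate_function
    def inner(x):
        # Step 2: sum of squares of the input.
        return jnp.sum(x ** 2)

    # Step 3: call the annotated helper and return its result.
    return inner(x)


print(annotated_call(jnp.array([1.0, 2.0, 3.0])))  # 14.0
```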

Hints

jax profiler annotate-function
