easy
primitives
@profiler.annotate_function
Why this matters
@jax.profiler.annotate_function is a decorator that wraps a function
in a profiler trace event. It is equivalent to wrapping the function
body in with jax.profiler.TraceAnnotation(name):. (jax.named_scope is a
related but different tool: it labels the operations in the traced
program rather than emitting a host-side trace event.) The decorator is
less verbose than a with block for simple helper functions, and makes
profiler traces label each function call with the function’s own name
by default. Like all profiler annotations, it has no effect on
numerical output.
Worked mini-example
import jax
import jax.numpy as jnp

@jax.profiler.annotate_function
def compute_loss(params, x):
    return jnp.sum((params * x) ** 2)

# Example inputs (illustrative values).
params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, 0.5, 0.5])

# In the profiler trace, this call appears as a labeled "compute_loss" span.
result = compute_loss(params, x)
This is equivalent to:
def compute_loss(params, x):
    with jax.profiler.TraceAnnotation("compute_loss"):
        return jnp.sum((params * x) ** 2)
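The claim that the annotation leaves numerical output unchanged can be checked directly. The following is a minimal sketch, assuming a hypothetical undecorated helper named plain_loss and illustrative input values that are not part of the original example. It applies the decorator as an ordinary function call and compares the results, including under jax.jit.

```python
import jax
import jax.numpy as jnp


def plain_loss(params, x):
    # Undecorated reference implementation (hypothetical helper name).
    return jnp.sum((params * x) ** 2)


# Same function, wrapped so that each call emits a profiler trace event.
annotated_loss = jax.profiler.annotate_function(plain_loss)

# Illustrative inputs (assumed values, chosen only for this sketch).
params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, 1.0, 2.0])

plain_result = plain_loss(params, x)
annotated_result = annotated_loss(params, x)

# The decorated function can be passed to jax.jit like any other callable.
jitted_result = jax.jit(annotated_loss)(params, x)

print(plain_result, annotated_result, jitted_result)
```

All three values should agree, since the wrapper only opens and closes a trace event around the call.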
Common pitfalls
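- The label defaults to the function’s own name, but it can be overridden: annotate_function accepts a name argument, passed via functools.partial(jax.profiler.annotate_function, name="..."). For tighter control over which part of the body is covered, use a with block directly (jax.profiler.TraceAnnotation, or jax.named_scope if the goal is to label the compiled operations).
- Doesn’t compose cleanly with closures: if you need to annotate a closure that captures variables, an explicit with block is clearer.
- Applying to methods on a class works. In recent JAX versions the default label is taken from the function’s __qualname__, so a method shows up with its class prefix (ClassName.method); pass name explicitly if you need a different label.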
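The first two pitfalls can be sketched as follows. This is an illustration under stated assumptions, not a canonical pattern: the function names squared_norm and make_scaled_norm and both label strings are hypothetical. It shows a custom label supplied through functools.partial, and an explicit jax.profiler.TraceAnnotation block inside a closure so the label can include the captured value.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Override the default label by passing `name` through functools.partial.
@partial(jax.profiler.annotate_function, name="squared_norm_step")
def squared_norm(x):
    return jnp.sum(x ** 2)


def make_scaled_norm(scale):
    # `scaled_norm` closes over `scale`; an explicit with block lets the
    # label mention the captured value.
    def scaled_norm(x):
        with jax.profiler.TraceAnnotation(f"scaled_norm_scale_{scale}"):
            return scale * jnp.sum(x ** 2)

    return scaled_norm


x = jnp.array([1.0, 2.0, 3.0])
a = squared_norm(x)
b = make_scaled_norm(2.0)(x)
```

Neither annotation changes the computed values; they only affect how the calls are labeled when a trace is being recorded.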
Problem
Implement annotated_call(x) that:
- Defines an inner function inner(x) decorated with @jax.profiler.annotate_function.
- Inside inner, computes and returns jnp.sum(x ** 2).
- Calls inner(x) and returns the result.

- x: 1-D JAX array.
Returns: scalar, sum(x^2).
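One possible shape for a solution, offered as a sketch that follows the three steps in the problem statement rather than as the site's reference answer:

```python
import jax
import jax.numpy as jnp


def annotated_call(x):
    # Step 1: inner function wrapped in a profiler trace event.
    @jax.profiler.annotate_function
    def inner(x):
        # Step 2: sum of squares of the input.
        return jnp.sum(x ** 2)

    # Step 3: call the annotated function and return its result.
    return inner(x)
```

For example, annotated_call(jnp.array([1.0, 2.0, 3.0])) returns a scalar array equal to 1 + 4 + 9 = 14.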