NNX Debug Callback
Why this matters
jax.debug.print (pos 95) is great for ad-hoc inspection, but
real instrumentation often needs more: write to a log file, send
metrics to W&B, dump a histogram to disk, push a number to a
Prometheus gauge. For that, you need to call arbitrary host-side
Python from inside a jitted forward.
jax.debug.callback is the primitive: it accepts a Python callable
and any number of JAX arrays, and at runtime it materializes those
arrays and calls the function with them.
    def log_norm(value):
        with open('/tmp/norms.log', 'a') as f:
            f.write(f"{float(value)}\n")

    jax.debug.callback(log_norm, jnp.linalg.norm(h))
Inside jit, the callback is queued. At runtime, JAX materializes
jnp.linalg.norm(h) to a concrete array, hands it to log_norm,
which writes the file. Side effects on the host; the array
computation never leaves the device until the callback fires.
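A minimal runnable sketch of the same flow, assuming a plain Python list (`norms`, a name introduced here for illustration) in place of the log file so the effect is easy to inspect. `jax.effects_barrier()` is called before reading the list, since callbacks may still be pending when the jitted call returns.

```python
import jax
import jax.numpy as jnp

norms = []  # hypothetical host-side sink standing in for /tmp/norms.log


def log_norm(value):
    # Runs on the host with a concrete value, never a tracer.
    norms.append(float(value))


@jax.jit
def forward(h):
    # The callback is staged into the compiled program at trace time
    # and fires when the program actually executes.
    jax.debug.callback(log_norm, jnp.linalg.norm(h))
    return h * 2.0


out = forward(jnp.array([3.0, 4.0]))
jax.effects_barrier()  # wait for pending callbacks before reading the sink
print(norms, out)
```

Swapping the list for a file handle or a metrics client changes only the body of `log_norm`; the jitted function stays the same.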
API
jax.debug.callback(callable, *args, ordered=False)
- callable: any Python function. It receives concrete arrays (or scalars), never tracers. Its return value is ignored.
- *args: JAX arrays. They are materialized and passed in order.
- ordered: same flag as jax.debug.print. False (the default) lets XLA reorder; True forces source order at a small performance cost.
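To make the ordered flag concrete, here is a small sketch with two callbacks in one jitted function. The `record` helper and the `calls` list are hypothetical names used only for this illustration; `functools.partial` attaches a static tag so the callback knows which call site fired.

```python
from functools import partial

import jax
import jax.numpy as jnp

calls = []  # hypothetical host-side record of (tag, value) pairs


def record(tag, value):
    # `tag` is bound with functools.partial; `value` arrives from JAX.
    calls.append((tag, float(value)))


@jax.jit
def step(x):
    # ordered=True asks JAX to keep these two callbacks in source order.
    jax.debug.callback(partial(record, "first"), x.sum(), ordered=True)
    y = x * 2.0
    jax.debug.callback(partial(record, "second"), y.sum(), ordered=True)
    return y


step(jnp.arange(3.0))
jax.effects_barrier()  # flush pending callbacks before inspecting `calls`
print(calls)
```

With ordered=False the same two entries would still be recorded, but their relative order would not be guaranteed.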
Compared to jax.debug.print:
| feature | print | callback |
|---|---|---|
| format string? | yes | no — it’s just args |
| side effect | stdout | whatever you write |
| host-side Python | implicit (printf) | explicit (your fn) |
| typical use | quick inspection | metrics/logging |
For pure logging-to-stdout, print is shorter. For “I want to call
wandb.log({...}) from inside a jitted training step”, callback
is the tool to reach for.
A no-op callback
For this problem, we use a callback whose body is lambda v: None
— a function that does literally nothing. It exercises the
machinery (the value gets materialized, the host-side function
gets called) without producing any external side effect. The model
output is unchanged.
In real code you’d swap lambda v: None for lambda v: log.info(...)
or lambda v: my_metrics.append(float(v)). The shape of the call
is the same.
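A quick way to convince yourself that the no-op callback leaves the result alone is to compare two otherwise identical functions. This is only a sketch with made-up arithmetic (`with_callback` and `without_callback` are illustrative names, not part of the problem):

```python
import jax
import jax.numpy as jnp


def with_callback(x):
    h = x * 3.0
    # No-op host-side callback: the norm is materialized and handed to
    # the lambda, which discards it.
    jax.debug.callback(lambda v: None, jnp.linalg.norm(h))
    return h + 1.0


def without_callback(x):
    h = x * 3.0
    return h + 1.0


x = jnp.arange(4.0)
a = jax.jit(with_callback)(x)
b = jax.jit(without_callback)(x)
print(bool(jnp.allclose(a, b)))
```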
Why this matters in nnx
Callbacks compose with everything: jit, vmap, grad, scan.
Inside an nnx __call__ they sit alongside ordinary arithmetic.
A typical training-instrumentation pattern is:
    class InstrumentedBlock(nnx.Module):
        def __call__(self, x):
            h = self.attn(x)
            jax.debug.callback(_my_metrics_fn, jnp.linalg.norm(h))
            return self.ffn(h)
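Toggle _my_metrics_fn between a real logger and a no-op for
train/eval modes. Or stash data into a closure-captured Python
list: with ordered=True the callbacks fire in program order, so
accumulating data is safe. Call jax.effects_barrier() before reading
the list, because callbacks can still be pending when the jitted call
returns.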
Common pitfalls
- Returning a value from the callback. Ignored. If you need the output back in the JAX computation, use jax.pure_callback or jax.experimental.io_callback: different primitives that DO return values. (The older jax.experimental.host_callback was deprecated in favor of these.) jax.debug.callback is fire-and-forget.
- Passing a Python int as an arg. It gets traced into JAX as a constant. For dynamic ints, wrap as jnp.array(...).
- Heavy work inside the callback. It runs on the host between steps; expensive callbacks serialize device execution. Keep them light (write a number, not analyze a tensor).
- Forgetting ordered=True when order matters. Two callbacks may run in surprising order otherwise.
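The first pitfall is the easiest to show side by side. In this sketch (with a hypothetical `host_double` helper), the value returned from the jax.debug.callback function is dropped, while jax.pure_callback feeds its result back into the computation. Note that pure_callback needs the result shape and dtype declared up front.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_double(v):
    # Host-side function; its result must match the declared shape/dtype.
    return np.asarray(v) * 2


@jax.jit
def f(x):
    # The 123 returned here is discarded: debug.callback is fire-and-forget.
    jax.debug.callback(lambda v: 123, x)
    # pure_callback returns a value into the traced computation.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_double, result_spec, x)


out = f(jnp.arange(3.0))
print(out)
```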
Problem
Write debug_callback_passthrough(seed, x, features):
- Define a CallbackModel(nnx.Module) with two nnx.Linear layers. The first maps x.shape[-1] -> features, the second maps features -> features.
- In __call__(x): compute h = self.l1(x). Call jax.debug.callback(lambda v: None, jnp.linalg.norm(h)), a no-op host-side callback that exercises the machinery. Return self.l2(h).
- The function builds the model and returns model(x).
The output is identical to a model without the callback (the callback has no return effect) — but the runtime path went through host-side Python.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: 1-D (features,) — the second layer’s output.