
NNX Debug Callback

Why this matters

jax.debug.print (pos 95) is great for ad-hoc inspection, but real instrumentation often needs more: write to a log file, send metrics to W&B, dump a histogram to disk, push a number to a Prometheus gauge. For that, you need to call arbitrary host-side Python from inside a jitted forward pass.

jax.debug.callback is the primitive: it accepts a Python callable and any number of JAX arrays, and at runtime it materializes those arrays and calls the function with them.

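```python
import jax
import jax.numpy as jnp

def log_norm(value):
    with open('/tmp/norms.log', 'a') as f:
        f.write(f"{float(value)}\n")  # runs on the host with a concrete value

h = jnp.ones((4,))  # example activations
jax.debug.callback(log_norm, jnp.linalg.norm(h))
```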

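Under jit, the callback is staged into the compiled program instead of running during tracing. At runtime, JAX computes jnp.linalg.norm(h) on the device, transfers that concrete value to the host, and passes it to log_norm, which appends it to the file. The side effect happens on the host, and only the values passed to the callback (here the norm, not h itself) leave the device.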
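To make the trace-time versus runtime split concrete, here is a minimal sketch (the names `norms`, `trace_count`, `record_norm` and `forward` are illustrative, not from any library). A Python counter increments only while JAX traces the function, while the callback fires on every execution with the concrete norm. `jax.effects_barrier()` waits for pending side effects before the host-side list is read.

```python
import jax
import jax.numpy as jnp

norms = []        # host-side list filled by the callback
trace_count = 0   # counts how many times Python traces `forward`

def record_norm(value):
    # Called at runtime with a concrete value, once per execution.
    norms.append(float(value))

@jax.jit
def forward(h):
    global trace_count
    trace_count += 1  # plain Python side effect: happens only while tracing
    jax.debug.callback(record_norm, jnp.linalg.norm(h))
    return 2.0 * h

forward(jnp.array([3.0, 4.0]))
forward(jnp.array([6.0, 8.0]))  # same shape and dtype, so no retrace
jax.effects_barrier()  # wait for pending callbacks before reading `norms`
print(trace_count, norms)
```

The function is traced once, yet `record_norm` runs for both calls, each time with the norm of that call's input.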

API

jax.debug.callback(callable, *args, ordered=False)
  • callable: any Python function. It receives the arguments as concrete (non-traced) array values. Its return value is ignored.
  • *args: JAX arrays, or pytrees of arrays. They are materialized and passed in the same order and structure.
  • ordered: the same flag as in jax.debug.print. With False (the default), XLA is free to reorder callbacks; True enforces program order among ordered callbacks at a small performance cost.
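A short sketch of the two less obvious parts of the API, pytree arguments and ordered=True. The helper `make_logger` and the `events` list are hypothetical names for illustration. The dict passed in comes back to the host as a dict of concrete values, and ordered=True guarantees the "input" record lands before the "output" record.

```python
import jax
import jax.numpy as jnp

events = []  # (tag, stats) pairs appended on the host

def make_logger(tag):
    # Bind the tag in a closure so only array data crosses the callback boundary.
    def log(stats):
        # `stats` arrives with the same dict structure that was passed in.
        events.append((tag, {k: float(v) for k, v in stats.items()}))
    return log

@jax.jit
def step(x):
    jax.debug.callback(make_logger("input"), {"mean": x.mean(), "max": x.max()}, ordered=True)
    y = jnp.tanh(x)
    jax.debug.callback(make_logger("output"), {"mean": y.mean(), "max": y.max()}, ordered=True)
    return y

step(jnp.array([0.0, 1.0, 2.0]))
jax.effects_barrier()
print([tag for tag, _ in events])
```

Binding non-array data such as the tag through a closure keeps the callback arguments purely numeric.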

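Compared to jax.debug.print:

feature           | jax.debug.print          | jax.debug.callback
------------------+--------------------------+-----------------------------
format string?    | yes                      | no, just args
side effect       | stdout                   | whatever your function does
host-side Python  | implicit (printf-style)  | explicit (your function)
typical use       | quick inspection         | metrics and logging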

For pure logging to stdout, print is shorter. To call something like wandb.log({...}) from inside a jitted training step, callback is the tool to reach for.
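A minimal sketch of the training-step use case, assuming a toy linear-regression loss. `fake_tracker_log` is a hypothetical stand-in for a tracker call such as wandb.log, so the example runs without an external service. The metrics dict is a pytree, so one callback ships both scalars to the host per step.

```python
import jax
import jax.numpy as jnp

logged = []

def fake_tracker_log(metrics):
    # Hypothetical stand-in for an experiment-tracker call such as wandb.log(metrics).
    logged.append({k: float(v) for k, v in metrics.items()})

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_step(w, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # Fire-and-forget: ship two scalars to the host, keep the update on device.
    jax.debug.callback(fake_tracker_log, {"loss": loss, "grad_norm": jnp.linalg.norm(grads)})
    return w - lr * grads

w = jnp.zeros((2,))
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])
for _ in range(3):
    w = train_step(w, x, y, 0.1)
jax.effects_barrier()
print(len(logged), logged[0])
```

Only two scalars cross to the host per step; the parameters and gradients stay on the device.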

A no-op callback

For this problem, we use a callback whose body is lambda v: None — a function that does literally nothing. It exercises the machinery (the value gets materialized, the host-side function gets called) without producing any external side effect. The model output is unchanged.

In real code you’d swap lambda v: None for lambda v: log.info(...) or lambda v: my_metrics.append(float(v)). The shape of the call is the same.
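Here is a sketch of the swap: the same jitted forward built once with the no-op and once with a list-appending function (`make_forward` and `metrics` are illustrative names). The outputs are identical because the callback's return value never feeds back into the computation.

```python
import jax
import jax.numpy as jnp

def make_forward(callback):
    # Same call shape regardless of which host-side function is plugged in.
    @jax.jit
    def forward(h):
        jax.debug.callback(callback, jnp.linalg.norm(h))
        return 3.0 * h
    return forward

metrics = []
noop_forward = make_forward(lambda v: None)                        # exercises the machinery only
record_forward = make_forward(lambda v: metrics.append(float(v)))  # hypothetical metrics sink

h = jnp.array([3.0, 4.0])
out_noop = noop_forward(h)
out_rec = record_forward(h)
jax.effects_barrier()
print(bool(jnp.array_equal(out_noop, out_rec)), metrics)  # True [5.0]
```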

Why this matters in nnx

jax.debug.callback composes with the core JAX transformations (jit, vmap, grad, scan). Inside an nnx __call__ it sits alongside ordinary arithmetic. A typical training-instrumentation pattern is:

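```python
import jax
import jax.numpy as jnp
from flax import nnx

class InstrumentedBlock(nnx.Module):
    # __init__ (omitted) creates self.attn and self.ffn; _my_metrics_fn is your logging function.
    def __call__(self, x):
        h = self.attn(x)
        jax.debug.callback(_my_metrics_fn, jnp.linalg.norm(h))
        return self.ffn(h)
```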

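Swap _my_metrics_fn between a real logger and a no-op for train and eval modes. The function is captured when the step is traced, so rebinding it after compilation has no effect until the step is retraced. Alternatively, stash data into a closure-captured Python list. With ordered=True, callbacks fire in program order, so the list fills in step order; because dispatch is asynchronous, call jax.effects_barrier() before reading the list.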

Common pitfalls

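  • Returning a value from the callback. It is ignored. If you need the result back in the JAX computation, use jax.pure_callback (or jax.experimental.io_callback for side-effecting calls); both return values. jax.debug.callback is fire-and-forget. (jax.experimental.host_callback, which older material mentions, is deprecated in recent JAX releases.)
  • Passing a Python int as an arg, for example a step counter from the enclosing Python scope. Its value is baked into the compiled program as a constant at trace time, so later calls keep reporting the old value. Pass it into the jitted function as an array argument so that it is traced.
  • Heavy work inside the callback. It runs on the host, and the device computation typically waits for it to return, so an expensive callback stalls every step. Keep callbacks light (record a number, do not analyze a whole tensor).
  • Forgetting ordered=True when order matters. Without it, XLA may run separate callbacks in a different order from the source.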
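A sketch of the first two pitfalls (all function and list names are illustrative). A global Python step counter is frozen at its trace-time value, while the same counter passed as an array argument updates on every call. The last part shows jax.pure_callback, which, unlike jax.debug.callback, returns a value; it needs the result's shape and dtype up front via `jax.ShapeDtypeStruct`.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen_baked, seen_traced = [], []
step = 0  # Python-level step counter

@jax.jit
def baked(x):
    # `step` is read once, at trace time, and becomes a constant in the compiled program.
    jax.debug.callback(lambda s: seen_baked.append(int(s)), step)
    return x + 1.0

@jax.jit
def traced(x, step_arr):
    # The step arrives as a traced argument, so each call sees its current value.
    jax.debug.callback(lambda s: seen_traced.append(int(s)), step_arr)
    return x + 1.0

x = jnp.zeros(())
for step in range(3):
    baked(x)
    traced(x, jnp.asarray(step))
jax.effects_barrier()
print(seen_baked, seen_traced)


def host_square(v):
    # Runs on the host; must return an array matching the declared shape and dtype.
    v = np.asarray(v)
    return (v * v).astype(v.dtype)

@jax.jit
def squared(x):
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_square, spec, x)

sq = squared(jnp.array([1.0, 2.0, 3.0]))
print(sq)
```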

Problem

Write debug_callback_passthrough(seed, x, features):

  1. Define a CallbackModel(nnx.Module) with two nnx.Linear layers. First maps x.shape[-1] -> features, second maps features -> features.
  2. In __call__(x): compute h = self.l1(x). Call jax.debug.callback(lambda v: None, jnp.linalg.norm(h)) — a no-op host-side callback that exercises the machinery. Return self.l2(h).
  3. The function builds the model and returns model(x).

The output is identical to a model without the callback (the callback has no return effect) — but the runtime path went through host-side Python.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: 1-D (features,) — the second layer’s output.

