jit with donate_argnums
Why this matters
Every jax.jit call normally allocates a fresh output buffer. For a large
model parameter array that gets updated on every training step, that means
one allocation per step, with the old and new copies briefly coexisting in
memory. That is wasteful when the old buffer is never used again.
donate_argnums=(k,) signals to XLA: “the caller will never read argument
k after this call; you may reuse its buffer for the output.” XLA can then
write the result into the donated buffer instead of allocating a new one,
provided an output has the same shape and dtype as the donated input. This
is an important optimisation for large-model training, where parameter
arrays can be gigabytes in size.
The contract is strict: after a donating call, the array passed as that
argument is invalidated, along with every other Python name bound to it.
Any attempt to read it raises a RuntimeError. Design your training loop to
never retain a reference to the donated argument; the usual pattern is to
rebind the name to the call's result.
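To make the contract concrete, here is a minimal sketch of what happens to a second name that still points at the donated array. The names `step` and `stale` are illustrative only. Whether a donation is actually applied can depend on the backend and JAX version, so the sketch checks `is_deleted()` rather than assuming the outcome.

```python
import jax
import jax.numpy as jnp

# Donate argument 0 (the state). `step` is a hypothetical name for this sketch.
step = jax.jit(lambda state, delta: state + delta, donate_argnums=(0,))

state = jnp.zeros((3,))
delta = jnp.ones((3,))

stale = state               # a second Python name for the same array
state = step(state, delta)  # rebind: `state` now refers to the fresh result

# Check whether the donation was applied instead of assuming it.
if stale.is_deleted():
    try:
        stale + 1           # reading a donated array fails
    except RuntimeError as err:
        print("stale read failed:", err)
else:
    print("donation was not applied here; `stale` is still readable")

print(state)
```

Either way, `state` itself is safe to use because it was rebound to the output of the call.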
Worked mini-example
import jax
import jax.numpy as jnp
params = jnp.ones((4,))
grads = jnp.full((4,), 0.01)
update = jax.jit(lambda p, g: p - g, donate_argnums=(0,))
params = update(params, grads)
# params now holds [0.99, 0.99, 0.99, 0.99]
# the old params buffer has been donated — do NOT read it again.
Common pitfalls
- Invalidated after the call: reading the donated argument after the call raises RuntimeError: Array has been deleted. Never keep a reference to donated arrays.
- Only donate if you own the sole reference: every Python name bound to the donated array is invalidated together, so any other code that still holds it will hit the same RuntimeError when it reads the array.
- Don’t donate if you need the old value: gradient accumulation loops that read-then-update must keep a copy; donate only when the old value is truly discarded.
- Multiple outputs: XLA can reuse a donated buffer only for an output with the same shape and dtype. With several outputs, do not assume which output receives it; if no output matches, JAX warns that some donated buffers were not usable.
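The "need the old value" pitfall can be sidestepped by copying before the donating call. The following is a small sketch under assumed names (`apply_grads`, `snapshot`, and the 0.1 learning rate are illustrative, not part of the exercise):

```python
import jax
import jax.numpy as jnp

# Hypothetical SGD-style update that donates the parameter buffer.
apply_grads = jax.jit(
    lambda params, grads: params - 0.1 * grads,
    donate_argnums=(0,),
)

params = jnp.ones((4,))
grads = jnp.full((4,), 2.0)

# Copy *before* donating: the snapshot owns a separate buffer.
snapshot = jnp.copy(params)

params = apply_grads(params, grads)

# The snapshot is unaffected by the donation and can still be read.
max_change = jnp.abs(params - snapshot).max()
print(max_change)
```

The copy costs one extra allocation, so it only makes sense when the old value is genuinely needed; otherwise donate and discard.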
Problem
Implement update_with_donate(state, delta) that returns state + delta.
Compile with jax.jit(..., donate_argnums=(0,)) so XLA may reuse the
state buffer for the output.
- state: 1-D JAX array.
- delta: 1-D JAX array, same shape as state.
Returns: 1-D array state + delta, same shape.
Example (not from the test set):
- update_with_donate(jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5])) returns [1.5, 2.5].
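One possible shape for a solution, offered as a sketch rather than the reference answer: define the addition as a plain function and wrap it with jax.jit, donating argument 0. The helper name `_add` is an assumption made for this illustration.

```python
import jax
import jax.numpy as jnp


def _add(state, delta):
    # Output has the same shape and dtype as `state`, so its buffer is reusable.
    return state + delta


# Donate argument 0 (`state`) so XLA may write the result into its buffer.
update_with_donate = jax.jit(_add, donate_argnums=(0,))

result = update_with_donate(jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5]))
print(result)
```

Because the inputs here are temporaries that nothing else refers to, donating the first one is safe.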