medium primitives

jit with donate_argnums

Why this matters

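Without donation, a jitted function writes its result into a freshly allocated output buffer. For a large parameter array that is updated on every training step, that means one new allocation per step, with the old and new arrays both held in memory until the old one is freed. This is wasteful when the old buffer is never read again.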

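donate_argnums=(k,) tells XLA: “the caller will never read argument k after this call; you may reuse its buffer for an output.” When the donated buffer matches an output in shape and dtype, XLA can write the result into it instead of allocating new memory. Donation is a permission, not a guarantee: if a donated buffer cannot be used, it is simply left unused and JAX emits a warning. The optimisation matters most in large-model training, where parameter arrays run to gigabytes.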

The contract is strict: after the call, treat any array passed as a donated argument as invalid. Once its buffer has been donated, any attempt to read it raises a RuntimeError. Design your training loop so that it never reads a donated argument again, typically by rebinding the name to the call's result.
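The rebinding pattern can be sketched as below. The step function, learning rate, and variable names are illustrative assumptions, not part of the original text. Whether the input is actually deleted depends on the backend and JAX version, so the sketch inspects the array with is_deleted() rather than assuming one outcome.

```python
import jax
import jax.numpy as jnp

# Hypothetical update step; argument 0 (the parameters) is donated.
step = jax.jit(lambda p, g: p - 0.1 * g, donate_argnums=(0,))

params = jnp.ones((4,))
grads = jnp.full((4,), 0.5)

# A second name for the same array object, kept here only to inspect it.
stale = params

for _ in range(3):
    # Rebind the name to the result so the donated array is never read again.
    params = step(params, grads)

# True if the backend honoured the donation, False if it was ignored
# (in which case JAX may have printed a warning instead).
input_deleted = stale.is_deleted()

print("original input deleted:", input_deleted)
print("current params:", params)
```

If input_deleted is True, reading stale would raise the RuntimeError described above, while params always refers to a live array.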

Worked mini-example

import jax
import jax.numpy as jnp

params = jnp.ones((4,))
grads  = jnp.full((4,), 0.01)

update = jax.jit(lambda p, g: p - g, donate_argnums=(0,))

params = update(params, grads)
# params now holds [0.99, 0.99, 0.99, 0.99]
# the old params buffer has been donated — do NOT read it again.

Common pitfalls

  • Invalidated after the call: reading a donated argument after the call raises RuntimeError: Array has been deleted. Never read a donated array again.
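  • Only donate if nothing else needs the array: donation deletes the array object itself, so every variable, list, or pytree that still refers to it raises the same RuntimeError when read. JAX does not leave a dangling pointer, but it also does not check for other references before donating.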
  • Don’t donate if you need the old value: gradient accumulation loops that read-then-update must keep a copy; donate only when the old value is truly discarded.
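  • Multiple outputs: a donated buffer is not tied to the first output. It can be reused only for an output with matching shape and dtype; if no output matches, the donation goes unused and JAX warns that some donated buffers were not usable.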
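When the old value is still needed, one option is to take an independent copy before the donating call. The following is a minimal sketch under that assumption; the accumulate function and the variable names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical accumulation step; argument 0 (the accumulator) is donated.
accumulate = jax.jit(lambda acc, g: acc + g, donate_argnums=(0,))

acc = jnp.zeros((3,))
g = jnp.array([1.0, 2.0, 3.0])

# Copy first: `previous` gets its own buffer, separate from `acc`.
previous = jnp.copy(acc)

# `acc` is donated here and rebound to the result.
acc = accumulate(acc, g)

# `previous` stays readable whether or not the donation was honoured.
change = acc - previous
print("change since last step:", change)
```

The copy costs exactly the allocation that donation was meant to save, so this only makes sense when the old value is needed occasionally rather than on every step.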

Problem

Implement update_with_donate(state, delta) that returns state + delta. Compile with jax.jit(..., donate_argnums=(0,)) so XLA may reuse the state buffer for the output.

  • state: 1-D JAX array.
  • delta: 1-D JAX array, same shape as state.

Returns: 1-D array state + delta, same shape.

Example (not from the test set):

  • update_with_donate(jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5])) returns [1.5, 2.5].
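One possible sketch of the requested function is shown below. It is not the site's reference solution. It applies jax.jit as a decorator through functools.partial, which is one of several equivalent ways to pass donate_argnums.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=(0,))
def update_with_donate(state, delta):
    # The output has the same shape and dtype as `state`,
    # so XLA may reuse the donated buffer for it.
    return state + delta


result = update_with_donate(jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5]))
print(result)
```

The arrays are created inline in the call, so no name is left pointing at the donated state. On a backend that does not use the donation, the call may emit a warning but still returns the same values.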

Hints

jax jit donate-argnums
