
block_until_ready for Sync

Why this matters

JAX dispatches computation asynchronously by default. When you write result = jnp.sum(x ** 2), JAX enqueues the work on the accelerator and immediately returns a JAX array whose value may not have been computed yet. It does not wait for the hardware to finish, so the Python host can go on dispatching the next operation while the device is still working.

The consequence: if you time a JAX computation with Python's time.time() or time.perf_counter() alone, you measure only dispatch latency (typically microseconds), not the actual compute time (which for large workloads can be milliseconds or more). The benchmark can look orders of magnitude faster than reality.

.block_until_ready() makes the host wait until the device has finished computing the array's value. Call it on the output inside the timed region, before you stop the clock.

Worked mini-example

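import time
import jax.numpy as jnp

x = jnp.ones((10_000_000,))   # large enough for compute time to be visible
x.block_until_ready()         # keep the creation of x out of the timings

# Warm-up: the first call compiles the ops, which would skew both timings.
jnp.sum(x ** 2).block_until_ready()

# Wrong: measures dispatch only
t0 = time.perf_counter()
result = jnp.sum(x ** 2)
print(time.perf_counter() - t0)   # dispatch latency only (misleadingly small)

# Correct: measures dispatch plus device compute
t0 = time.perf_counter()
result = jnp.sum(x ** 2)
result.block_until_ready()
print(time.perf_counter() - t0)   # actual time for the computation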
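To reuse the correct pattern, it can be wrapped in a small helper. This is a minimal sketch: time_fn is a hypothetical name, not a JAX API. It adds one warm-up call so one-off compilation cost is excluded, and it uses jax.block_until_ready, which blocks on every array in a pytree, so it also works for functions that return several outputs.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, n_runs=10):
    """Return the mean wall-clock seconds per call of fn(*args).

    Hypothetical helper for illustration, not part of JAX.
    """
    # Warm-up call: absorbs tracing/compilation cost on first use.
    jax.block_until_ready(fn(*args))

    t0 = time.perf_counter()
    for _ in range(n_runs):
        # Wait for every output array before starting the next call,
        # so each iteration measures dispatch plus device compute.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - t0) / n_runs


x = jnp.ones((1_000_000,))
sum_sq = jax.jit(lambda v: jnp.sum(v ** 2))
print(f"{time_fn(sum_sq, x) * 1e6:.1f} µs per call")
```

Blocking inside the loop measures per-call latency; if you instead want throughput, you could let calls queue up and block only on the last output.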

Common pitfalls

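  • Mostly a benchmarking concern: correctness is unaffected, because JAX synchronises automatically whenever a value is read back to the host (e.g. float(result), print(result), np.asarray(result)).
  • Block on the right array: in a chain of ops, block on the final output. Blocking on an intermediate only waits for that intermediate, so later ops may still be running. If a function returns several arrays, block on all of them (jax.block_until_ready accepts a whole tuple or other pytree).
  • Don't call it inside jit: .block_until_ready() is a host-side operation. Inside a traced function the value is an abstract tracer with nothing to wait for, so the call cannot do anything useful there and may raise an error. Block on the jitted function's output instead.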
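The pitfalls above can be illustrated with a jitted function that returns two arrays. This is a sketch; stats is a hypothetical example function. Synchronisation happens outside the jit boundary, on the whole output tuple, and reading values with float() would also wait implicitly.

```python
import jax
import jax.numpy as jnp


@jax.jit
def stats(x):
    # No .block_until_ready() in here: inside jit, x is a tracer.
    return jnp.sum(x ** 2), jnp.max(x)


x = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]

# The jitted call returns immediately with two possibly unfinished arrays.
outputs = stats(x)

# Block on the whole output pytree, not just one element of it.
sum_sq, max_val = jax.block_until_ready(outputs)

# Reading a value on the host (float(), print, np.asarray) also waits implicitly.
print(float(sum_sq), float(max_val))  # 30.0 4.0
```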

Problem

Implement sync_compute(x) that computes jnp.sum(x ** 2), calls .block_until_ready() on the result to force synchronisation, then returns the result.

  • x: 1-D JAX array.

Returns: a scalar (0-d JAX array) equal to sum(x ** 2).

Example (not from the test set):

  • sync_compute(jnp.array([3.0, 4.0])) returns 25.0.
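One straightforward way to meet this specification is sketched below; it is a minimal sketch, not necessarily the reference solution. It relies on the fact that .block_until_ready() waits for the device and returns the same array.

```python
import jax.numpy as jnp


def sync_compute(x):
    """Return sum(x ** 2), waiting until the device has finished computing it."""
    result = jnp.sum(x ** 2)
    # Wait for the device; block_until_ready() returns the same array,
    # so this could also be written as jnp.sum(x ** 2).block_until_ready().
    result.block_until_ready()
    return result


print(sync_compute(jnp.array([3.0, 4.0])))  # 25.0
```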

Hints

jax async block-until-ready
