block_until_ready for Sync
Why this matters
JAX dispatches computation asynchronously by default. When you write
result = jnp.sum(x ** 2), JAX submits the work to the accelerator and
returns a future-like JAX array immediately, without waiting for the
hardware to finish. This lets the Python host keep dispatching the next
operations while the device works in parallel.
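Here is a minimal sketch of asynchronous dispatch. It assumes a recent JAX install on any backend. The hypothetical function dispatch_vs_total enqueues a chain of dependent matrix multiplications. It then records two times: how long the Python loop takes to return, and how long it takes until the final result is actually ready. The sizes are arbitrary. On an accelerator the first number is usually much smaller than the second, while on CPU the gap may be small.

```python
import time

import jax
import jax.numpy as jnp


def dispatch_vs_total(n_steps: int = 10, size: int = 512) -> tuple[float, float]:
    """Hypothetical demo: return (seconds until the loop finishes dispatching,
    seconds until the device has finished the last result)."""
    key = jax.random.PRNGKey(0)
    a = jax.random.normal(key, (size, size))
    # Warm-up so one-time compilation of the matmul is not timed.
    (a @ a).block_until_ready()

    t0 = time.perf_counter()
    y = a
    for _ in range(n_steps):
        # Each matmul is enqueued and can return before the device has
        # finished it; the loop itself does not wait.
        y = y @ a / size
    t_dispatch = time.perf_counter() - t0

    # Wait for the last result. Each step depends on the previous one,
    # so all earlier steps must have finished as well.
    y.block_until_ready()
    t_total = time.perf_counter() - t0
    return t_dispatch, t_total


t_dispatch, t_total = dispatch_vs_total()
print(f"dispatch: {t_dispatch:.4f} s, total: {t_total:.4f} s")
```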
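The consequence: if you time a JAX computation with Python’s time.time()
or time.perf_counter() and never wait for the result, you are measuring
only dispatch latency (typically microseconds), not the actual compute
time. For large workloads, that compute time can be milliseconds or
longer, so the benchmark can look orders of magnitude faster than reality.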
.block_until_ready() forces the host to wait until the device has
finished computing the array’s value. Call it on the result you are
timing before you stop the timer.
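The same rule can be packaged into a small timing helper. This is a sketch, not a library API. The name time_jax is hypothetical, and the helper assumes fn returns a JAX array or a pytree of arrays. jax.block_until_ready, available in recent JAX versions, waits on every leaf of such a pytree. A warm-up call keeps one-time compilation out of the measurement.

```python
import time
from typing import Any, Callable

import jax
import jax.numpy as jnp


def time_jax(fn: Callable[..., Any], *args: Any, n_iter: int = 10) -> float:
    """Hypothetical helper: mean wall-clock seconds per call of fn(*args).

    Assumes fn returns a JAX array or a pytree of JAX arrays.
    """
    # Warm-up call, excluded from timing (compilation and caching).
    jax.block_until_ready(fn(*args))

    t0 = time.perf_counter()
    for _ in range(n_iter):
        # Wait for each call's output before starting the next one.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - t0) / n_iter


x = jnp.ones((10_000,))
sum_sq = jax.jit(lambda v: jnp.sum(v ** 2))
print(f"{time_jax(sum_sq, x) * 1e6:.1f} us per call")
```

Blocking inside the loop measures per-call latency. Blocking only once after the loop lets calls overlap, which roughly measures throughput instead. That answers a different question.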
Worked mini-example
```python
import time

import jax.numpy as jnp

x = jnp.ones((10_000,))

# Warm-up: the first call also pays one-time compilation, so keep it untimed.
jnp.sum(x ** 2).block_until_ready()

# Wrong: measures dispatch only.
t0 = time.perf_counter()
result = jnp.sum(x ** 2)
print(time.perf_counter() - t0)  # dispatch latency only (misleadingly small)

# Correct: measures real compute.
t0 = time.perf_counter()
result = jnp.sum(x ** 2)
result.block_until_ready()
print(time.perf_counter() - t0)  # dispatch + device compute; depends on hardware
```
Common pitfalls
- Only matters for benchmarking: normal correctness is unaffected, because JAX syncs automatically when you read a value (e.g., float(result), print(result)).
- Block on the right array: in a chain of ops, block on the last output, not an intermediate one. Blocking on an intermediate lets later ops keep running after the timer stops.
- Calling it inside jit does not help: .block_until_ready() is a host-side concept. Inside a traced function nothing is running yet, so call it on the jitted function's output instead.
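This short sketch covers the three pitfalls together, using a hypothetical jitted function chain. First, block_until_ready is called on the jitted function's output rather than inside it. Second, an eager chain is blocked on its last array. Third, float() reads a value, which waits on its own without any explicit call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def chain(x):
    # Several dependent ops compiled together. No block_until_ready here:
    # during tracing x is abstract and nothing has been computed yet.
    y = jnp.sin(x)
    z = y * 2.0
    return jnp.sum(z ** 2)


x = jnp.linspace(0.0, 1.0, 1_000)

# Block on the final output, outside the jitted function.
out = chain(x).block_until_ready()

# Eager chain: block on the last array, not the intermediate one.
a = jnp.exp(x)         # intermediate
b = jnp.cumsum(a)      # final output
b.block_until_ready()  # waits for b; a must finish first since b depends on it

# Reading a value (float(), print, np.asarray) also waits implicitly,
# so correctness never depends on calling block_until_ready.
value = float(out)
print(value)
```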
Problem
Implement sync_compute(x) that computes jnp.sum(x ** 2), calls
.block_until_ready() on the result to force synchronisation, then
returns the result.
- x: 1-D JAX array.
- Returns: scalar, sum(x**2).
Example (not from the test set):
- sync_compute(jnp.array([3.0, 4.0])) returns 25.0.
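Here is one possible solution, written only from the problem statement above. The site's own solution is not shown here and may differ.

```python
import jax.numpy as jnp


def sync_compute(x):
    """Return sum(x**2) after forcing the device to finish computing it."""
    result = jnp.sum(x ** 2)
    # Wait for the asynchronous computation before handing the value back.
    result.block_until_ready()
    return result


print(sync_compute(jnp.array([3.0, 4.0])))  # 25.0
```

Since .block_until_ready() returns the array it was called on, the body could equally be written as return jnp.sum(x ** 2).block_until_ready().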