Difficulty: easy
Topic: primitives
jnp.allclose for Tolerance Equality
Why this matters
Floating-point arithmetic makes exact equality comparisons unreliable.
jnp.allclose(a, b, atol=k) returns True when every element satisfies
|a - b| <= atol + rtol * |b|, where rtol defaults to 1e-5 and atol to 1e-8.
The test is not symmetric: the relative term uses |b| only. It is the standard
JAX idiom for asserting numerical correctness in tests, comparing two
implementations of the same function, or verifying solver convergence.
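To make the formula concrete, here is a minimal sketch that evaluates the elementwise test by hand and compares it with jnp.allclose on the same arrays used in the worked example below. The helper name allclose_by_hand is hypothetical, and unlike jnp.allclose it does no special handling of infinities or NaN.

```python
import jax.numpy as jnp

def allclose_by_hand(a, b, rtol=1e-5, atol=1e-8):
    # Hypothetical helper: elementwise |a - b| <= atol + rtol * |b|,
    # then require the condition for every element.
    # Defaults copy jnp.allclose's rtol=1e-05 and atol=1e-08.
    # Unlike jnp.allclose, this sketch does not special-case inf or NaN.
    return jnp.all(jnp.abs(a - b) <= atol + rtol * jnp.abs(b))

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0001, 2.0001, 3.0001])

for atol in (1e-3, 1e-5):
    by_hand = bool(allclose_by_hand(a, b, atol=atol))
    library = bool(jnp.allclose(a, b, atol=atol))
    print(f"atol={atol}: by hand={by_hand}, jnp.allclose={library}")
```

For finite inputs the two agree: both accept the roughly 1e-4 gap at atol=1e-3 and reject it at atol=1e-5.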
Worked mini-example
import jax.numpy as jnp
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0001, 2.0001, 3.0001])
jnp.allclose(a, b, atol=1e-3) # True: every gap (about 1e-4) is within 0.001
jnp.allclose(a, b, atol=1e-5) # False: a gap of about 1e-4 exceeds 1e-5 + rtol * |b|
Combine with jnp.where for a jit-safe conversion from boolean to float:
jnp.where(jnp.allclose(a, b, atol=1e-3), 1.0, 0.0) # → 1.0
Common pitfalls
- Don't use == for float comparisons: a == b on float32 arrays returns False wherever values differ by even one rounding step, even when the difference is within an acceptable tolerance.
- rtol also matters: the default rtol=1e-5 adds a relative term rtol * |b| to the tolerance, while at very small magnitudes the default atol=1e-8 can dominate instead. Set both atol and rtol explicitly.
- jnp.allclose returns a scalar boolean: inside jax.jit it is a traced value, so you cannot use it in a Python if statement. Use jnp.where instead.
- Shape mismatch raises an error: a and b must be broadcastable.
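The first three pitfalls can be reproduced in a few lines. This is a minimal sketch: b = a + 1e-6 is an assumed perturbation standing in for accumulated rounding error, and close_with_python_if is a hypothetical function written only to trigger the jit failure.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = a + 1e-6  # assumed tiny perturbation standing in for rounding error

# Pitfall 1: exact equality rejects values that differ by a few float32 steps.
exact_match = bool(jnp.all(a == b))          # False
tolerant_match = bool(jnp.allclose(a, b))    # True: default rtol=1e-5 covers ~1e-6

# Pitfall 2: at tiny magnitudes the default atol=1e-8 dominates the test.
p = jnp.array([1e-9])
q = jnp.array([2e-9])                        # a 100% relative difference
loose = bool(jnp.allclose(p, q))             # True: 1e-9 <= 1e-8 + 1e-5 * 2e-9
strict = bool(jnp.allclose(p, q, atol=0.0))  # False: only rtol * |q| remains

# Pitfall 3: a Python `if` on allclose's result fails under jit, because
# the result is a traced value with no concrete True/False during tracing.
@jax.jit
def close_with_python_if(x, y):
    if jnp.allclose(x, y):
        return 1.0
    return 0.0

python_if_failed = False
try:
    close_with_python_if(a, b)
except jax.errors.ConcretizationTypeError:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this.
    python_if_failed = True

print(exact_match, tolerant_match, loose, strict, python_if_failed)
```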
Problem
Implement values_equal_within_atol(a, b, atol) that:
- Calls jnp.allclose(a, b, atol=atol).
- Returns 1.0 if close and 0.0 otherwise, using jnp.where.
Arguments:
- a, b: 1-D JAX arrays of the same shape.
- atol: scalar absolute tolerance.
Returns: scalar, 1.0 or 0.0.
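One possible sketch of values_equal_within_atol, assuming a 0-d JAX array is acceptable as the scalar return value. It is not the site's reference solution, and the jax.jit decorator is an optional addition included to show that the function stays trace-safe.

```python
import jax
import jax.numpy as jnp

@jax.jit  # optional; shows the function works under tracing
def values_equal_within_atol(a, b, atol):
    # Only atol is overridden; rtol keeps jnp.allclose's default of 1e-05.
    close = jnp.allclose(a, b, atol=atol)
    # `close` is a 0-d boolean array (a tracer under jit), so choose the
    # float result with jnp.where instead of a Python `if`.
    return jnp.where(close, 1.0, 0.0)

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0001, 2.0001, 3.0001])
print(values_equal_within_atol(a, b, 1e-3))  # evaluates to 1.0
print(values_equal_within_atol(a, b, 1e-5))  # evaluates to 0.0
```

Because atol is passed as an ordinary traced argument rather than a static one, calling the jitted function with a different tolerance reuses the same compiled code.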
Tags: jax, allclose, asserts