log1p / expm1 for Small Values
Why this matters
Floating-point arithmetic loses precision when you add a tiny number to 1.0
before taking a log. In float32, 1.0 + 1e-8 == 1.0: the small number is
completely swallowed. So log(1.0 + 1e-8) returns exactly 0.0 instead of
≈ 1e-8. This silent precision loss corrupts probability calculations and
financial math.
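A minimal sketch of where the swallowing starts, assuming JAX's default float32. The helper `absorbed` is hypothetical, ours for illustration only, and not part of any API. float32's machine epsilon (2^-23, about 1.19e-7) is the gap between 1.0 and the next float32. A sum 1 + x rounds back to exactly 1.0 only when x is below half that gap; slightly larger values round up to 1 + eps instead, which is still a large relative error.

```python
import jax.numpy as jnp

# float32 machine epsilon: gap between 1.0 and the next representable float32.
EPS32 = float(jnp.finfo(jnp.float32).eps)  # 2**-23, about 1.19e-7


def absorbed(x: float) -> bool:
    """Hypothetical helper: True if 1.0 + x rounds back to exactly 1.0 in float32."""
    one = jnp.float32(1.0)
    return bool(one + jnp.float32(x) == one)


print(EPS32)
print(absorbed(1e-8))  # True: below EPS32 / 2, swallowed entirely
print(absorbed(1e-7))  # False: above EPS32 / 2, rounds up to 1 + EPS32
```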
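Two functions fix this. JAX provides them as `jnp.log1p` and `jnp.expm1`, mirroring `log1p`/`expm1` in C's and Python's standard math libraries: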
| Function | Computes | Accurate for |
|---|---|---|
| `jnp.log1p(x)` | log(1 + x) | small x (avoids the rounding error in forming 1 + x) |
| `jnp.expm1(x)` | exp(x) - 1 | small x (avoids cancellation in exp(x) - 1) |
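The table's second row can be checked the same way. This sketch compares `jnp.exp(x) - 1` with `jnp.expm1(x)` in float32; the input 1e-8 is our choice, and the reference value uses exp(x) - 1 ≈ x, which is exact to about 5e-17 at this size.

```python
import jax.numpy as jnp

x = jnp.float32(1e-8)
true_value = 1e-8  # exp(x) - 1 ≈ x for tiny x (error ~5e-17)

# Naive: exp(x) rounds to (or within an ulp of) 1.0, so subtracting 1
# leaves essentially no correct digits.
naive = jnp.exp(x) - 1.0

# Stable: expm1 is designed to stay accurate near x = 0.
stable = jnp.expm1(x)

rel_err_naive = abs(float(naive) - true_value) / true_value
rel_err_stable = abs(float(stable) - true_value) / true_value
print(f"naive: {rel_err_naive:.1e}  expm1: {rel_err_stable:.1e}")
```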
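They are mathematical inverses, so log1p(expm1(x)) recovers x to within a few
float32 ulps for inputs of moderate size. The round trip degrades for large
negative x, where expm1(x) approaches -1 and the detail is rounded away, and it
fails once expm1(x) overflows to inf (x above about 88.7 in float32).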
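The round trip is only as good as the intermediate value expm1(x). A quick float32 sweep (the sample points are our own choice) shows both regimes: moderate inputs come back to within float32 rounding, a large negative input loses everything because expm1 rounds to exactly -1, and a large positive input overflows.

```python
import jax.numpy as jnp

xs = jnp.array([-30.0, -1.0, 0.0, 1e-6, 0.5, 10.0, 100.0], dtype=jnp.float32)
roundtrip = jnp.log1p(jnp.expm1(xs))

# Moderate inputs (everything except the two endpoints): recovered closely.
moderate_ok = bool(jnp.allclose(roundtrip[1:-1], xs[1:-1], rtol=1e-5, atol=1e-7))

# x = -30: expm1(-30) = -1 + 9.4e-14, which rounds to -1.0; log1p(-1.0) = -inf.
# x = 100: exp(100) exceeds the float32 maximum (~3.4e38), so expm1 -> inf.
print(moderate_ok, roundtrip[0], roundtrip[-1])
```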
Worked mini-example
```python
import jax.numpy as jnp

x = jnp.array(1e-8)  # float32 by default in JAX

# Naive: 1.0 + x rounds to exactly 1.0 in float32, so the log is 0.0
naive = jnp.log(1.0 + x)  # == 0.0 (WRONG)

# Stable: accurate to within float32 rounding
good = jnp.log1p(x)  # ≈ 1e-8 (correct)
```
Round-trip check (should recover x to within float32 rounding):
```python
import jax.numpy as jnp

x = jnp.array([0.0, 0.001, 0.5])
recovered = jnp.log1p(jnp.expm1(x))
# ≈ [0.0, 0.001, 0.5]: identity to float32 precision
```
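Why the naive formula loses accuracy gradually rather than all at once follows from a standard first-order error model. This is our sketch, not taken from the original text; it assumes the only error is the rounding of 1 + x, with float32 unit roundoff u = 2^-24:

$$
\begin{aligned}
\mathrm{fl}(1 + x) &= (1 + x)(1 + \delta), \qquad |\delta| \le u = 2^{-24} \approx 5.96 \times 10^{-8},\\
\log \mathrm{fl}(1 + x) &= \log(1 + x) + \log(1 + \delta) \approx \log(1 + x) + \delta,\\
\frac{\lvert \log \mathrm{fl}(1 + x) - \log(1 + x) \rvert}{\lvert \log(1 + x) \rvert} &\approx \frac{|\delta|}{\lvert \log(1 + x) \rvert} \lesssim \frac{u}{|x|} \quad \text{for small } |x|.
\end{aligned}
$$

At x = 1e-4 the bound is about 6e-4; near x = 1e-8 it exceeds 1, meaning no correct digits survive. A well-implemented log1p keeps its relative error near u across this whole range.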
Common pitfalls
- `log(1 + x)` for tiny x: `1 + x` is rounded to float32 first, so accuracy drains away as x shrinks, and for positive x below about 6e-8 (half of float32's machine epsilon) `1 + x` is exactly 1.0 and every significant digit is gone.
- `exp(x) - 1` for tiny x: `exp(x)` is very close to 1, so the subtraction cancels most significant bits.
- Only meaningful for small x: for large x (e.g., x = 10), `log1p` and `log(1 + x)` agree; the benefit is only near x ≈ 0.
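The last two pitfalls can be seen side by side in a small sweep. This sketch uses NumPy's float64 `log1p` as the reference (our choice) and measures the relative error of the float32 naive and stable results at a few magnitudes we picked; the naive error shrinks as x grows and is negligible by x = 10.

```python
import numpy as np
import jax.numpy as jnp

# Sample magnitudes (our choice), from tiny to large.
xs64 = np.array([1e-8, 1e-6, 1e-4, 1e-2, 10.0])
reference = np.log1p(xs64)  # float64 reference values

xs32 = jnp.asarray(xs64, dtype=jnp.float32)
naive = np.asarray(jnp.log(1.0 + xs32), dtype=np.float64)
stable = np.asarray(jnp.log1p(xs32), dtype=np.float64)

# Relative error of each float32 result against the float64 reference.
naive_err = np.abs(naive - reference) / reference
stable_err = np.abs(stable - reference) / reference

for x, e_naive, e_stable in zip(xs64, naive_err, stable_err):
    print(f"x={x:.0e}  naive={e_naive:.1e}  log1p={e_stable:.1e}")
```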
Problem
Implement small_value_safe(x) that round-trips x through log1p(expm1(x)).
- `x`: 1-D JAX array.
- Returns: a 1-D array of the same shape, which should equal `x` to high precision.