Dtype Promotion Rules
Why this matters
JAX’s default floating-point precision is float32 (single precision). Even if you write
jnp.array(1.0), which would be float64 in NumPy, JAX creates a float32 array.
This is intentional: ML workloads are dominated by float32, and
double precision is slow and memory-hungry on accelerators. To enable float64 you must set
jax.config.update("jax_enable_x64", True) at program startup, before any JAX arrays are created, not partway through.
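A quick way to see the default is to compare the dtype that NumPy and JAX pick for the same literal. This is a minimal sketch assuming a fresh process with JAX's default configuration (x64 disabled); the jax_enable_x64 lines are left commented out because the flag is meant to be set at startup, not toggled inside an example.

```python
import numpy as np
import jax.numpy as jnp

# Uncommenting these two lines at the very top of a fresh script would switch
# JAX to 64-bit mode; they stay commented so this sketch shows the default.
# import jax
# jax.config.update("jax_enable_x64", True)

x_np = np.array(1.0)    # NumPy picks float64 for a Python float literal
x_jax = jnp.array(1.0)  # JAX (x64 disabled) picks float32 for the same literal

print(x_np.dtype)   # float64
print(x_jax.dtype)  # float32
```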
The promotion rules also differ subtly from NumPy:
- In NumPy: int32 + float64 → float64 (the float wins).
- In JAX with x64 disabled (the default): int32 + float64 → float32 (every float is float32).
- In JAX with x64 enabled: matches NumPy, int32 + float64 → float64.
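The NumPy case and the default-JAX case can be checked directly. This sketch assumes JAX's default configuration (x64 disabled); the x64-enabled case is only described in a comment, since flipping the flag mid-program is exactly what the text advises against.

```python
import numpy as np
import jax.numpy as jnp

i32 = np.array([1, 2, 3], dtype=np.int32)
f64 = np.array([0.5, 0.5, 0.5], dtype=np.float64)

# NumPy: int32 + float64 -> float64 (the float wins).
numpy_sum = i32 + f64

# JAX with x64 disabled: the float64 operand cannot stay float64,
# so the result is float32. (With x64 enabled it would be float64.)
jax_sum = jnp.asarray(i32) + jnp.asarray(f64)

print(numpy_sum.dtype)  # float64
print(jax_sum.dtype)    # float32
```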
Knowing this prevents two silent failures. With x64 disabled, float64 data is
quietly reduced to float32, losing precision. With x64 enabled, if you intend a float32
accumulator and accidentally bring in a float64 array (e.g., from NumPy or
PyTorch), the result is promoted to float64: slow on TPU, and twice the memory, which
can push long training runs into OOM.
The fix is explicit casts: arr.astype(jnp.float32). Make precision a
contract, not an inference.
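One way to make the contract explicit is to cast at the boundary where outside data enters a float32 computation. The helper name accumulate_f32 below is hypothetical; the point is only that the cast happens on the way in, so the running total stays float32 whether or not x64 mode is on.

```python
import numpy as np
import jax.numpy as jnp

def accumulate_f32(total, batch):
    """Hypothetical helper: add `batch` into a float32 running total.

    `batch` may arrive as a NumPy float64 array. Casting it to float32 here
    keeps `total` float32 even when jax_enable_x64 is on, where the uncast
    float64 operand would otherwise promote the result to float64.
    """
    return total + jnp.asarray(batch).astype(jnp.float32)

total = jnp.zeros(3, dtype=jnp.float32)
for _ in range(4):
    batch = np.full(3, 0.25)  # NumPy default dtype: float64
    total = accumulate_f32(total, batch)

print(total.dtype)  # float32
```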
Worked mini-example
import jax.numpy as jnp
# Default config (x64 disabled):
a = jnp.array([1, 2, 3], dtype=jnp.int32)
b = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
print((a + b).dtype)  # float32
# If b were requested as float64:
b64 = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float64)
# With x64 disabled, b64 is truncated to float32 (JAX emits a UserWarning).
# With x64 enabled, (a + b64).dtype would be float64.
# Best practice: cast both inputs explicitly, so the result is float32 in either mode:
result = a.astype(jnp.float32) + b64.astype(jnp.float32)
print(result.dtype)  # float32
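To treat the output dtype as a contract, you can also check it without running the computation. This sketch uses jax.eval_shape, which traces a function on abstract inputs (jax.ShapeDtypeStruct) and returns only the output's shape and dtype; add_f32 is a hypothetical name for the cast-both-inputs pattern from the worked example.

```python
import jax
import jax.numpy as jnp

def add_f32(a, b):
    # Cast both operands explicitly, as in the worked example.
    return a.astype(jnp.float32) + b.astype(jnp.float32)

# Abstract inputs: only shape and dtype, no data is allocated or computed.
spec_a = jax.ShapeDtypeStruct((3,), jnp.int32)
spec_b = jax.ShapeDtypeStruct((3,), jnp.float32)

out = jax.eval_shape(add_f32, spec_a, spec_b)
print(out.shape, out.dtype)  # (3,) float32
```

The same check could run in a unit test so that a change to add_f32 that alters its output dtype fails loudly instead of silently.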
Common pitfalls
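- Forgetting jax.config.update("jax_enable_x64", True): float64 data is downcast to float32, usually silently (only an explicit dtype=jnp.float64 request triggers a warning). Set it at the very top of your script, before any JAX arrays are created.
- Mixing JAX and NumPy arrays: NumPy creates float64 by default; jnp.array(np_arr) downcasts it to float32 unless x64 is enabled, and with x64 enabled it stays float64 and can upcast your float32 math.
- jnp.float32(x) vs x.astype(jnp.float32): both return a float32 JAX array of the same shape. jnp.float32(x) accepts any array-like (a Python scalar becomes a 0-D array), while .astype is a method on an existing array.
- Index dtypes: indices passed to .at[] must have an integer dtype. Python ints and JAX int32 arrays work, but a float-valued index array raises an error; cast it with .astype(jnp.int32).
- TPU precision: by default, TPUs compute float32 matrix multiplications and convolutions with bfloat16 arithmetic internally. For numerically sensitive code, pass precision=jax.lax.Precision.HIGHEST to ops such as jnp.dot or jnp.matmul; if bfloat16 is acceptable, cast to jnp.bfloat16 explicitly so the choice is visible.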
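Three of these pitfalls in code. This is a minimal sketch assuming the default (x64 disabled) configuration; precision=jax.lax.Precision.HIGHEST changes results mainly on hardware such as TPU, where the default float32 matmul precision is lower, and the sketch runs unchanged on CPU.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)  # int32 under the default config

# jnp.float32(x) and x.astype(jnp.float32) both return a float32 array
# with the same shape; neither collapses x to a scalar.
via_ctor = jnp.float32(x)
via_astype = x.astype(jnp.float32)

# Float-valued index arrays must be cast to an integer dtype before .at[].
idx = jnp.array([0.0, 2.0])  # float32 values that happen to be whole numbers
y = jnp.zeros(4).at[idx.astype(jnp.int32)].set(1.0)

# Request full float32 precision for a matmul; this matters mainly on TPU,
# where the default float32 matmul precision is lower.
m = jnp.full((2, 2), 0.1, dtype=jnp.float32)
prod = jnp.matmul(m, m, precision=jax.lax.Precision.HIGHEST)

print(via_ctor.shape, via_ctor.dtype)  # (4,) float32
print(prod.dtype)                      # float32
```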
Problem
Implement force_float32(x_int, x_float64) that adds two 1-D arrays and
returns a 1-D float32 array — the sum, explicitly cast.
The test harness delivers both inputs as floats. Treat x_int as holding
integer values (conceptually int32) and x_float64 as holding
double-precision values. Your function must produce a float32 result
regardless of whether JAX’s x64 mode is enabled.
Two illustrative examples (not from the test set):
- x_int = [5.0, 10.0], x_float64 = [0.25, 0.75]: Sum = [5.25, 10.75]. Cast to float32 → array([5.25, 10.75], dtype=float32)
- x_int = [0.0, -3.0, 7.0], x_float64 = [1.1, 2.2, 3.3]: Sum = [1.1, -0.8, 10.3]. Cast to float32 → array([1.1, -0.8, 10.3], dtype=float32)
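A minimal sketch of one way to meet this spec, assuming the inputs arrive as Python lists or NumPy arrays of floats. It adds first and then casts the sum, so the result is float32 in both modes; with x64 enabled the addition happens in float64 and is rounded to float32 once at the end.

```python
import jax.numpy as jnp

def force_float32(x_int, x_float64):
    # Convert inputs to JAX arrays. With x64 disabled they become float32;
    # with x64 enabled, float inputs stay float64.
    a = jnp.asarray(x_int)
    b = jnp.asarray(x_float64)
    # Explicit final cast: the return dtype is float32 regardless of x64 mode.
    return (a + b).astype(jnp.float32)

print(force_float32([5.0, 10.0], [0.25, 0.75]))
```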