hard primitives

Dtype Promotion Rules

Why this matters

JAX’s default floating-point type is float32 (single precision). jnp.array(1.0) yields a float32 array, whereas np.array(1.0) is float64 in NumPy. This is intentional: ML workloads are dominated by float32, and double precision is costly on accelerators. To enable float64, call jax.config.update("jax_enable_x64", True) at startup, before any JAX arrays are created or any JAX functions run, or set the JAX_ENABLE_X64 environment variable. Do not flip it partway through a program.
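A minimal sketch of the startup pattern, assuming a fresh Python process in which no JAX arrays exist yet. The variable names are illustrative only.

```python
import jax

# Opt in to 64-bit types first, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With x64 enabled, a Python float becomes a float64 array.
x = jnp.array(1.0)

# An explicitly requested float32 is still honored.
y = jnp.array(1.0, dtype=jnp.float32)

print(x.dtype, y.dtype)
```

Leaving out the config call gives float32 for both arrays, which is the default behavior described above.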

The promotion rules also differ subtly from NumPy:

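  • In NumPy: int32 + float64 → float64 (float wins). An int32 array plus a float32 array also widens to float64.
  • In JAX (x64 disabled, default): float64 arrays cannot be created. A requested float64 is truncated to float32, so int32 + "float64" → float32.
  • In JAX (x64 enabled): int32 + float64 → float64, matching NumPy. JAX still differs elsewhere: int32 + float32 → float32, where NumPy gives float64.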
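The int32 + float32 case is where the two libraries visibly disagree, and it behaves the same whether or not x64 is enabled. The following sketch compares the result dtypes side by side; the variable names are made up for illustration.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: int32 array + float32 array widens to float64.
i32_np = np.array([1, 2, 3], dtype=np.int32)
f32_np = np.array([0.5, 0.5, 0.5], dtype=np.float32)
numpy_dtype = (i32_np + f32_np).dtype

# JAX: the same operation stays in float32.
i32 = jnp.array([1, 2, 3], dtype=jnp.int32)
f32 = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
jax_dtype = (i32 + f32).dtype

# Python scalars are weakly typed in JAX: they adopt the array's dtype.
scalar_dtype = (f32 * 2.0).dtype

print(numpy_dtype, jax_dtype, scalar_dtype)
```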

Knowing this prevents silent precision surprises. With x64 enabled, a float64 array that slips in (e.g., from NumPy or PyTorch) promotes your float32 accumulator to float64, which is slow on TPU and doubles memory use, inviting OOM on long training runs. With x64 disabled, the same array is truncated to float32 without an error. The fix in both cases is an explicit cast: arr.astype(jnp.float32). Make precision a contract, not an inference.
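One way to enforce that contract is to cast at the boundary where external data enters JAX. The sketch below uses a hypothetical helper, to_float32, which is not part of JAX; it only wraps jnp.asarray with an explicit dtype.

```python
import numpy as np
import jax.numpy as jnp


def to_float32(x):
    """Hypothetical helper: convert any array-like to a float32 JAX array."""
    return jnp.asarray(x, dtype=jnp.float32)


# NumPy creates float64 by default.
np_batch = np.linspace(0.0, 1.0, 4)

# The accumulator is declared float32 and stays float32 after the update,
# because the incoming data is cast before it is added.
acc = jnp.zeros(4, dtype=jnp.float32)
acc = acc + to_float32(np_batch)

print(np_batch.dtype, acc.dtype)
```

The cast happens once, at the point of entry, so downstream code never has to reason about which config is active.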

Worked mini-example

import jax, jax.numpy as jnp

# Default config (x64 disabled):
a = jnp.array([1, 2, 3], dtype=jnp.int32)
b = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
(a + b).dtype          # → float32

# Requesting float64 explicitly:
b64 = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float64)
# In default config (x64 disabled), JAX emits a UserWarning and b64 is truncated to float32.
# In x64-enabled config, b64 is float64 and (a + b64).dtype is float64.

# Best practice: cast both inputs explicitly so the result is float32 in either config:
result = a.astype(jnp.float32) + b64.astype(jnp.float32)

Common pitfalls

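  • Forgetting jax.config.update("jax_enable_x64", True): float64 data is truncated to float32, with at most a warning. If you need float64, set the flag at the very top of your script.
  • Mixing JAX and NumPy arrays: NumPy creates float64 by default; jnp.array(np_arr) truncates it to float32 unless x64 is enabled.
  • jnp.float32(x) vs x.astype(jnp.float32): the first constructs a new float32 value from x (a 0-D array when x is a Python scalar); the second casts an existing array and is the clearer choice for arrays.
  • jnp.int32 vs Python int: .at[] accepts Python ints and integer arrays as indices, but not float arrays. Cast index arrays that arrive as floats with .astype(jnp.int32).
  • TPU precision: by default, TPU computes float32 matrix multiplications and convolutions with bfloat16 passes internally. For numerically sensitive code, pass precision=jax.lax.Precision.HIGHEST to the operation; use jnp.bfloat16 explicitly only where reduced precision is intended.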
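Two of these pitfalls are easy to show in a few lines: casting float-valued indices before using .at[], and requesting higher matmul precision. This is a sketch with made-up array names; on CPU the precision argument is accepted but is not expected to change the result.

```python
import jax
import jax.numpy as jnp

values = jnp.zeros(5, dtype=jnp.float32)

# Indices that arrive as floats are cast to int32 before indexing.
float_idx = jnp.array([0.0, 3.0])
idx = float_idx.astype(jnp.int32)
updated = values.at[idx].set(1.0)

# A plain Python int works as an index without any cast.
plain = values.at[2].set(7.0)

# Ask for the highest available precision on a float32 matmul.
a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((3, 2), dtype=jnp.float32)
c = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

print(updated, plain, c.dtype)
```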

Problem

Implement force_float32(x_int, x_float64) that adds two 1-D arrays and returns the sum as a 1-D float32 array, explicitly cast.

The test harness delivers both inputs as floats. Treat x_int as holding integer values (conceptually int32) and x_float64 as holding double-precision values. Your function must produce a float32 result regardless of whether JAX’s x64 mode is enabled.

Two illustrative examples (not from the test set):

  • x_int = [5.0, 10.0], x_float64 = [0.25, 0.75]: Sum = [5.25, 10.75]. Cast to float32 → array([5.25, 10.75], dtype=float32)

  • x_int = [0.0, -3.0, 7.0], x_float64 = [1.1, 2.2, 3.3]: Sum = [1.1, -0.8, 10.3]. Cast to float32 → array([1.1, -0.8, 10.3], dtype=float32)
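For orientation, here is one possible sketch of the function, not the reference solution. It assumes the inputs are array-likes that jnp.asarray can convert, and it follows the wording above literally: add first, then cast the sum.

```python
import jax.numpy as jnp


def force_float32(x_int, x_float64):
    # Convert the inputs to JAX arrays; their dtypes depend on the x64 setting.
    a = jnp.asarray(x_int)
    b = jnp.asarray(x_float64)
    # Add, then cast the sum so the result is float32 in either configuration.
    return (a + b).astype(jnp.float32)


out = force_float32([5.0, 10.0], [0.25, 0.75])
print(out, out.dtype)
```

Casting each input to float32 before the addition is an equally reasonable variant; with x64 enabled the two orderings can differ in the last bit for values that float32 cannot represent exactly.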

Hints

jax dtype precision
