easy
primitives
Explicit Dtype Promotion
Why this matters
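JAX applies its own dtype promotion rules, which differ from NumPy's. When you mix
int32 and float32 in an operation without an explicit cast, JAX promotes the result
to float32 (NumPy would choose float64), and the default widths of newly created
arrays depend on the jax_enable_x64 flag. The promotion can also lose information
silently: float32 has a 24-bit significand, so integers larger than 2**24 in
magnitude may be rounded. Always cast explicitly: the intent (including any
precision trade-off) is visible and the behaviour is portable.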
# Safe: cast BEFORE the op
result = x_int32.astype(jnp.float32) + y_float32
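To see the rules described above, you can query them directly. This minimal sketch compares jnp.promote_types with NumPy's np.promote_types for the int32/float32 pair, then shows the rounding that float32 applies to an integer just above 2**24. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# JAX's promotion lattice: int32 combined with float32 yields float32.
jax_result = jnp.promote_types(jnp.int32, jnp.float32)
# NumPy widens the same pair to float64 so every int32 value fits exactly.
numpy_result = np.promote_types(np.int32, np.float32)
print(jax_result, numpy_result)  # float32 float64

# float32 has a 24-bit significand, so 2**24 + 1 cannot be stored exactly.
big = jnp.array(2**24 + 1, dtype=jnp.int32)
rounded = big.astype(jnp.float32)  # rounds to 2**24 == 16777216
print(rounded)
```

The explicit cast does not prevent this rounding; it makes the conversion visible at the point where it happens.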
Worked mini-example
import jax.numpy as jnp
x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
result = x.astype(jnp.float32) + y
print(result) # [1.5 2.5 3.5]
print(result.dtype) # float32
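The explicit cast is a real operation in the traced program, not just a Python-level annotation. A minimal sketch, assuming a recent JAX release: jax.make_jaxpr traces the function and prints its jaxpr, in which astype appears as a convert_element_type primitive ahead of the add. The function name add_promoted is hypothetical.

```python
import jax
import jax.numpy as jnp

def add_promoted(x, y):
    # Cast first, so the add sees two float32 operands.
    return x.astype(jnp.float32) + y

x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)

# Trace the function and print its intermediate representation (jaxpr).
jaxpr = jax.make_jaxpr(add_promoted)(x, y)
print(jaxpr)
```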
Common pitfalls
- Casting AFTER the op: (x_int32 + y_float32).astype(jnp.float32) performs the addition first, in whatever dtype implicit promotion selects, and only then casts the result.
- Relying on implicit promotion: JAX's default rules differ from NumPy's, and the default integer and float widths change depending on the jax_enable_x64 flag. Always be explicit.
- Casting to the wrong target: cast the integer to float32, not the float to int32 (which would truncate).
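The wrong-target pitfall is easy to reproduce with the worked example's inputs. In this sketch, casting the floats to int32 drops the fractional 0.5 (conversion truncates toward zero), so the result silently stays integral.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)

# Correct target: promote the integers and keep the fractional parts.
good = x.astype(jnp.float32) + y
# Wrong target: 0.5 becomes 0 when cast to int32, so nothing is added.
bad = x + y.astype(jnp.int32)

print(good)       # [1.5 2.5 3.5]
print(bad)        # [1 2 3]
print(bad.dtype)  # int32
```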
Problem
Implement promote_to_float32(x_int32, y_float32) that:
- Casts x_int32 to jnp.float32.
- Adds it to y_float32.
- Returns the 1-D float32 result.
Arguments:
- x_int32: 1-D JAX array (may be passed as float32 values representing integers).
- y_float32: 1-D JAX array of the same shape.
Returns: 1-D float32 array.
Examples (not from the test set):
- promote_to_float32(jnp.array([1, 2]), jnp.array([0.5, 0.5])) → [1.5, 2.5]
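One possible implementation, offered as a sketch rather than the official solution. It converts both inputs with jnp.asarray and casts x to float32, which is a no-op when the values already arrive as float32. It also casts y: that is an added assumption, chosen so the output stays float32 even if y arrives as float64 under jax_enable_x64.

```python
import jax.numpy as jnp

def promote_to_float32(x_int32, y_float32):
    """Cast x to float32, add y, and return a 1-D float32 array."""
    x = jnp.asarray(x_int32)
    y = jnp.asarray(y_float32)
    # Cast BEFORE the add; casting y as well pins the output dtype.
    return x.astype(jnp.float32) + y.astype(jnp.float32)

out = promote_to_float32(jnp.array([1, 2]), jnp.array([0.5, 0.5]))
print(out)  # [1.5 2.5]
```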
Hints
jax, dtype, promotion