
Explicit Dtype Promotion

Why this matters

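JAX applies its own dtype promotion rules, which differ from NumPy's. When you mix int32 and float32 in one operation without an explicit cast, those rules, not you, choose the result dtype. The dtypes your arrays start with also depend on the jax_enable_x64 flag: for example, jnp.array([1, 2]) is int32 by default but int64 when x64 is enabled. Because JAX promotes an integer to the width of the float operand, an int64 mixed with float32 silently becomes float32 and can lose precision. Always cast explicitly: the intent is clear and the behaviour is portable.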
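To see that JAX's rules are not NumPy's, the sketch below adds the same int32 and float32 values in both libraries and compares the result dtypes. It assumes default JAX settings (x64 disabled); the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# Identical values in both libraries: int32 integers and float32 halves.
x_np = np.array([1, 2, 3], dtype=np.int32)
y_np = np.array([0.5, 0.5, 0.5], dtype=np.float32)
x_jx = jnp.array([1, 2, 3], dtype=jnp.int32)
y_jx = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)

# Implicit promotion in each library.
numpy_dtype = (x_np + y_np).dtype  # NumPy widens int32 + float32 to float64
jax_dtype = (x_jx + y_jx).dtype    # JAX keeps the float operand's width: float32

print(numpy_dtype, jax_dtype)  # float64 float32
```

Neither result is wrong, but code that silently relies on one library's behaviour will not reproduce the other's.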

# Safe: cast BEFORE the op
result = x_int32.astype(jnp.float32) + y_float32

Worked mini-example

import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
result = x.astype(jnp.float32) + y
print(result)        # [1.5 2.5 3.5]
print(result.dtype)  # float32

Common pitfalls

  • Casting AFTER the op: (x_int32 + y_float32).astype(jnp.float32) performs the addition under the implicit rules first, so the later cast cannot recover any precision already lost.
  • Relying on implicit promotion: JAX's rules differ from NumPy's (int32 + float32 gives float32 in JAX but float64 in NumPy), and default dtypes change with the jax_enable_x64 flag. Always be explicit.
  • Casting to the wrong target: cast the integer to float32, not the float to int32 (which would truncate).
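The wrong-target pitfall is easy to reproduce. This minimal sketch reuses the inputs of the worked example: casting the float operand to int32 truncates every 0.5 toward zero, so y quietly contributes nothing and the result stays int32.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)

# Pitfall: casting the float operand to int32 truncates 0.5 to 0.
wrong = x + y.astype(jnp.int32)
# Correct: cast the integer operand to float32 before adding.
right = x.astype(jnp.float32) + y

print(wrong, wrong.dtype)  # [1 2 3] int32
print(right, right.dtype)  # [1.5 2.5 3.5] float32
```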

Problem

Implement promote_to_float32(x_int32, y_float32) that:

  1. Casts x_int32 to jnp.float32.
  2. Adds it to y_float32.
  3. Returns the 1-D float32 result.

Parameters:

  • x_int32: 1-D JAX array (may be passed as float32 values representing integers).
  • y_float32: 1-D JAX array of the same shape.

Returns: 1-D float32 array.

Examples (not from the test set):

  • promote_to_float32(jnp.array([1, 2]), jnp.array([0.5, 0.5])) → [1.5, 2.5]
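One possible sketch that follows the three steps of the problem; it is an illustration, not the site's reference solution, and it uses only jnp.asarray and astype.

```python
import jax.numpy as jnp

def promote_to_float32(x_int32, y_float32):
    """Cast x_int32 to float32, add y_float32, and return a 1-D float32 array."""
    # Step 1: explicit cast of the integer operand (float32 input passes through unchanged).
    x_f32 = jnp.asarray(x_int32).astype(jnp.float32)
    # Step 2: add in float32; casting y too keeps the result float32 even if y arrives as float64.
    return x_f32 + jnp.asarray(y_float32).astype(jnp.float32)

out = promote_to_float32(jnp.array([1, 2]), jnp.array([0.5, 0.5]))
print(out, out.dtype)  # [1.5 2.5] float32
```

Casting y_float32 as well is a defensive choice: with x64 enabled, jnp.array([0.5, 0.5]) is float64, and float32 + float64 would otherwise promote the result to float64.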

Hints

jax dtype promotion
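Following the hint, jax.numpy lets you query the promotion rules directly instead of guessing: jnp.promote_types works on a pair of dtypes and jnp.result_type on arrays. A small sketch with illustrative inputs, assuming default settings:

```python
import jax.numpy as jnp

# Ask the promotion rules what int32 combined with float32 becomes.
pair_dtype = jnp.promote_types(jnp.int32, jnp.float32)

# result_type applies the same rules to concrete arrays.
x = jnp.array([1, 2], dtype=jnp.int32)
y = jnp.array([0.5, 0.5], dtype=jnp.float32)
array_dtype = jnp.result_type(x, y)

print(pair_dtype, array_dtype)  # float32 float32
```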
