easy primitives

Explicit Broadcasting via broadcast_to

Why this matters

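JAX (like NumPy) broadcasts automatically in arithmetic, but sometimes you need to materialize the broadcast for downstream ops, or you want to document intent clearly. jnp.broadcast_to(x, target_shape) is the explicit primitive: it returns an array holding the values of x expanded to shape target_shape. Unlike np.broadcast_to, which returns a read-only view, the JAX result is an ordinary immutable array. Whether it is actually allocated depends on context: under jax.jit, XLA can often fuse the broadcast into the ops that consume it.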

Use cases include:

  • Tiling a bias vector to match a batch: jnp.broadcast_to(bias, (B, D)).
  • Explicit shape documentation in a model architecture.
  • Preparing arrays for custom scatter/gather ops that require explicit shapes.
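  • Replacing jnp.tile when every repeated axis is new or has size 1 (broadcast_to states the target shape directly, and under jit the compiler can often avoid materializing the repeats).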
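The bias use case above can be sketched as follows. The names B, D, bias, and activations are placeholders chosen for illustration, not part of any particular model.

```python
import jax.numpy as jnp

B, D = 4, 3                                  # hypothetical batch size and feature width
bias = jnp.array([0.1, 0.2, 0.3])            # shape (D,)
activations = jnp.ones((B, D))               # stand-in for a layer's output

# Explicitly expand the bias to the batch shape before adding it.
bias_2d = jnp.broadcast_to(bias, (B, D))     # shape (B, D)
out_explicit = activations + bias_2d

# Implicit broadcasting gives the same values without naming the shape.
out_implicit = activations + bias

# jnp.tile reaches the same array by repeating the data B times along a new axis.
bias_tiled = jnp.tile(bias, (B, 1))          # shape (B, D)
```

All three routes produce the same values; the explicit call mainly buys a shape that is visible in the code.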

Worked mini-example

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])          # shape (3,)
y = jnp.broadcast_to(x, (4, 3))          # shape (4, 3)
# y[0] = y[1] = y[2] = y[3] = [1, 2, 3]

Compare with implicit broadcasting in arithmetic:

# Implicit broadcast (same values; the shape comes from the other operand):
z = jnp.zeros((4, 3)) + x    # shape (4, 3), same result as y
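Broadcasting lines up trailing axes, so repeating x along columns instead of rows needs a size-1 trailing axis first. The sketch below uses jnp.expand_dims for that step; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])               # shape (3,)

# Rows: trailing axes line up, so (3,) broadcasts to (4, 3) directly.
rows = jnp.broadcast_to(x, (4, 3))

# Columns: add a trailing size-1 axis first, then expand it.
col = jnp.expand_dims(x, axis=1)             # shape (3, 1)
cols = jnp.broadcast_to(col, (3, 4))         # shape (3, 4); row i repeats x[i]

# The explicit and implicit routes agree on the values.
same = jnp.array_equal(rows, jnp.zeros((4, 3)) + x)
```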

Common pitfalls

  • Shape must be a static tuple of Python ints: a float-valued array, or a value traced under jax.jit, is rejected. Extract each dimension first, e.g. int(target_shape[0]).
  • Target must be broadcast-compatible: you cannot shrink an axis. Broadcasting only adds leading axes or expands axes of size 1.
  • The result is immutable, like every JAX array: .at[].set() returns a new array and leaves the original unchanged. This is fine but can surprise you if you expect in-place mutation.
  • Confusing with jnp.tile: broadcast_to expands to an exact target shape, and only along new or size-1 axes; jnp.tile repeats the data a given number of times along each axis. Prefer broadcast_to when you know the target shape and the repetition fits those rules.
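A minimal sketch of the first three pitfalls, assuming eager (non-jit) execution so that int() can read concrete values. The flag incompatible_raised is only there to make the error case observable.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])               # shape (3,)
target_shape = jnp.array([4.0, 3.0])         # float-valued shape

# Pitfall 1: convert each dimension to a Python int before calling broadcast_to.
shape = (int(target_shape[0]), int(target_shape[1]))
y = jnp.broadcast_to(x, shape)               # shape (4, 3)

# Pitfall 2: an incompatible target raises instead of truncating.
try:
    jnp.broadcast_to(x, (4, 2))              # trailing axis of size 3 cannot become 2
    incompatible_raised = False
except (ValueError, TypeError):
    incompatible_raised = True

# Pitfall 3: .at[].set() returns a new array; y itself is unchanged.
y2 = y.at[0, 0].set(99.0)
```

Inside a jit-compiled function the int() conversion would fail on a traced target_shape, so the shape has to be known before tracing.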

Problem

Implement explicit_broadcast(x, target_shape) that broadcasts a 1-D array x of shape (W,) to shape (H, W) using jnp.broadcast_to.

  • x: 1-D array of shape (W,).
  • target_shape: 1-D array of 2 floats [H, W]; cast each entry to int.
  • Returns: 2-D array of shape (H, W) where every row equals x.

Two illustrative examples (not from the test set):

  • x = [5.0, 6.0], target_shape = [3, 2]: out = [[5, 6], [5, 6], [5, 6]]

  • x = [0.0], target_shape = [2, 1]: out = [[0], [0]]
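One possible shape for the function, written as a sketch rather than the reference solution. It assumes target_shape arrives as a concrete (non-traced) array, so each entry can be converted with int().

```python
import jax.numpy as jnp

def explicit_broadcast(x, target_shape):
    """Broadcast a 1-D array x of shape (W,) to shape (H, W)."""
    x = jnp.asarray(x)
    # target_shape holds floats, so convert each entry to a Python int.
    h, w = (int(d) for d in target_shape)
    return jnp.broadcast_to(x, (h, w))

# First illustrative example from the problem statement.
out = explicit_broadcast(jnp.array([5.0, 6.0]), jnp.array([3.0, 2.0]))
```

If W in target_shape does not match the length of x, jnp.broadcast_to raises rather than silently reshaping, which is usually the behavior you want here.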

Hints

jax broadcasting shape-manipulation
