Explicit Broadcasting via broadcast_to
Why this matters
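JAX (like NumPy) broadcasts automatically in arithmetic, but sometimes you
need the broadcast result as an array in its own right for downstream ops, or you want to
document intent clearly. jnp.broadcast_to(x, target_shape) is the explicit
form: it returns an array of shape target_shape whose entries repeat x along the new
or size-1 axes. Unlike NumPy's np.broadcast_to, which returns a read-only strided view,
JAX arrays are immutable and have no views: in eager mode the result is a new array,
although under jax.jit XLA can often fuse the broadcast into the operations that consume it,
so the full array is never materialized.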
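For the common case, the same values can be produced with jax.lax.broadcast_in_dim, which takes the operand, the output shape, and broadcast_dimensions (which output axis each input axis maps to). The sketch below checks that the two agree for a (3,) to (4, 3) broadcast; it illustrates the semantics only and is not a claim about how jnp.broadcast_to is implemented internally.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])  # shape (3,)

# NumPy-style API: new axes are added on the left.
y = jnp.broadcast_to(x, (4, 3))

# Lower-level equivalent: input axis 0 maps to output axis 1,
# so output axis 0 is the newly added (broadcast) axis.
y_lax = lax.broadcast_in_dim(x, (4, 3), broadcast_dimensions=(1,))

print(y.shape, bool(jnp.array_equal(y, y_lax)))  # (4, 3) True
```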
Use cases include:
- Tiling a bias vector to match a batch: jnp.broadcast_to(bias, (B, D)).
- Explicit shape documentation in a model architecture.
- Preparing arrays for custom scatter/gather ops that require explicit shapes.
- Replacing jnp.tile when all you need is to repeat data along new leading axes or size-1 axes: broadcast_to states that intent directly.
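A minimal sketch of the bias and jnp.tile use cases, assuming a hypothetical batch of activations h of shape (B, D) and a bias vector of shape (D,); the names and sizes are illustrative only. It checks that the explicit broadcast, jnp.tile along a new leading axis, and implicit broadcasting in the addition all give the same values.

```python
import jax.numpy as jnp

B, D = 4, 3                                # hypothetical batch size and feature width
h = jnp.ones((B, D))                       # hypothetical activations, shape (B, D)
bias = jnp.array([0.1, 0.2, 0.3])          # shape (D,)

bias_bd = jnp.broadcast_to(bias, (B, D))   # explicit: shape (B, D), every row equals bias
bias_tiled = jnp.tile(bias, (B, 1))        # tile: bias promoted to (1, D), repeated B times

out_explicit = h + bias_bd                 # shapes already match
out_implicit = h + bias                    # implicit broadcast, same values
```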
Worked mini-example
```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])   # shape (3,)
y = jnp.broadcast_to(x, (4, 3))  # shape (4, 3)
# y[0] = y[1] = y[2] = y[3] = [1, 2, 3]
```
Compare with implicit broadcasting in arithmetic:
```python
# Implicit broadcast: x is stretched to (4, 3) inside the addition,
# but the broadcast of x is never exposed as its own array.
z = jnp.zeros((4, 3)) + x  # shape (4, 3), same values as y
```
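Because broadcast_to aligns x with the trailing axes of the target shape, a (4,) vector cannot be broadcast directly to (4, 3) to repeat it across columns. A short sketch of the usual fix: add the missing axis first with jnp.expand_dims (equivalently x[:, None]), giving shape (4, 1), so the size-1 axis can then be expanded.

```python
import jax.numpy as jnp

col = jnp.array([1.0, 2.0, 3.0, 4.0])   # shape (4,)

# jnp.broadcast_to(col, (4, 3)) would raise: trailing axis 4 vs 3 is incompatible.
col_2d = jnp.expand_dims(col, axis=1)    # shape (4, 1)
w = jnp.broadcast_to(col_2d, (4, 3))     # shape (4, 3); every column equals col
```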
Common pitfalls
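- Shape must be static, a tuple of Python ints: passing a JAX array as the shape fails for float-typed arrays and for traced values inside jax.jit. Convert each entry outside traced code, e.g. tuple(int(d) for d in target_shape).
- Target must be broadcast-compatible: you cannot shrink an axis. Broadcasting only adds new axes on the left or expands size-1 axes, so each axis of x must equal the matching trailing target axis or be 1. (3,) to (4, 3) works; (3,) to (3, 4) raises a ValueError.
- The result is immutable, like every JAX array: y.at[i, j].set(v) returns a new array and leaves y (and x) unchanged, and the write changes only that one element, not every row. Don't expect in-place mutation.
- Confusing it with jnp.tile: jnp.tile repeats data and can grow an existing axis (tiling a (3,) array twice gives shape (6,)), which broadcast_to cannot do. When you only need to add leading axes or expand size-1 axes, the two give the same values and broadcast_to expresses the intent more directly.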
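The pitfalls above can be checked in a few lines. This is a sketch run eagerly (outside jax.jit); shape_arr is a hypothetical float-valued shape of the kind the problem below receives.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])             # shape (3,)

# 1. Static shape: turn an array-valued shape into a tuple of Python ints first.
shape_arr = jnp.array([2.0, 3.0])          # hypothetical shape delivered as floats
shape = tuple(int(d) for d in shape_arr)   # (2, 3); needs concrete (non-traced) values
y = jnp.broadcast_to(x, shape)

# 2. Compatibility: x's axis (3) must match the trailing target axis or be 1.
try:
    jnp.broadcast_to(x, (3, 2))
    incompatible_raised = False
except ValueError:
    incompatible_raised = True

# 3. Immutability: .at[].set() returns a new array; y and x are untouched,
#    and only element [0, 0] of the new array changes.
y2 = y.at[0, 0].set(99.0)

# 4. jnp.tile can grow an existing axis; broadcast_to cannot.
t = jnp.tile(x, 2)                         # shape (6,)
```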
Problem
Implement explicit_broadcast(x, target_shape) that broadcasts a 1-D
array x of shape (W,) to shape (H, W) using jnp.broadcast_to.
- x: 1-D array of shape (W,).
- target_shape: 1-D array of 2 floats [H, W]; cast to int.
- Returns: 2-D array of shape (H, W) where every row equals x.
Two illustrative examples (not from the test set):
- x = [5.0, 6.0], target_shape = [3, 2]: out = [[5, 6], [5, 6], [5, 6]]
- x = [0.0], target_shape = [2, 1]: out = [[0], [0]]
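One possible sketch of explicit_broadcast, assuming it is called eagerly with concrete inputs; it is not the site's reference solution. The only real work is converting the float entries of target_shape to Python ints before calling jnp.broadcast_to.

```python
import jax.numpy as jnp

def explicit_broadcast(x, target_shape):
    """Broadcast a 1-D array x of shape (W,) to shape (H, W).

    target_shape arrives as a float array [H, W]; jnp.broadcast_to needs a
    static tuple of Python ints, so each entry is converted with int().
    """
    h, w = (int(d) for d in target_shape)
    return jnp.broadcast_to(x, (h, w))

out = explicit_broadcast(jnp.array([5.0, 6.0]), jnp.array([3.0, 2.0]))
```

Because int() needs concrete values, this sketch does not work if target_shape is traced inside jax.jit; there the shape would have to be passed as a static, hashable tuple instead.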