medium primitives

Fix: List Append → jnp.concatenate

Why this matters

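A common porting mistake from NumPy/Python is building an array by mutating a Python list inside a jit-compiled function. Python list operations happen at trace time, not at run time. Inside jax.jit, array values are abstract tracers. So list(x) unrolls into one indexing operation per element, and jnp.array(result) then has to re-stack every element, which makes trace size and compile time grow with the length of x. If the list lives outside the function (a global or a closure variable), it is read and mutated only while the function is being traced. Later calls that hit the compilation cache see stale contents, and any tracer stored in the list leaks out of the trace, raising UnexpectedTracerError when it is used. TracerBoolConversionError is a different failure: it comes from Python control flow on a traced value, such as if val > 0.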
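The trace-time behaviour is easy to observe with a Python list that lives outside the jitted function. This is a minimal sketch; trace_log and double are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # a Python list living outside the jitted function

@jax.jit
def double(x):
    trace_log.append("traced")  # Python side effect: runs only while JAX traces
    return x * 2

double(jnp.array([1.0, 2.0]))       # first call: traces and compiles
double(jnp.array([3.0, 4.0]))       # same shape and dtype: cached, no retrace
double(jnp.array([1.0, 2.0, 3.0]))  # new shape: retraces

print(len(trace_log))  # 2
```

The list grows once per trace, not once per call, so anything a jitted function appends to an outer list reflects tracing, not execution.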

The BROKEN pattern:

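import jax
import jax.numpy as jnp

@jax.jit
def broken_append(x, val):
    result = list(x)          # unrolls into one indexing op per element at trace time
    result.append(val)        # Python-side mutation; runs only while tracing
    return jnp.array(result)  # re-stacks the N + 1 traced scalars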
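Even when the list-based pattern traces without an error, it scales poorly because the Python list work is unrolled into the traced program. The sketch below uses jax.make_jaxpr to count traced operations for two input lengths; list_append, concat_append and count_eqns are hypothetical helper names. Exact counts depend on the JAX version, so the sketch only compares them rather than asserting specific numbers.

```python
import jax
import jax.numpy as jnp

def list_append(x, val):
    result = list(x)          # one indexing op per element when traced
    result.append(val)
    return jnp.array(result)  # stacks N + 1 scalars back into an array

def concat_append(x, val):
    # Single concatenation, independent of len(x)
    return jnp.concatenate([x, jnp.atleast_1d(val)])

def count_eqns(fn, n):
    # Number of top-level primitive equations in the traced program
    x = jnp.arange(n, dtype=jnp.float32)
    return len(jax.make_jaxpr(fn)(x, 4.0).jaxpr.eqns)

for n in (3, 30):
    print(n, count_eqns(list_append, n), count_eqns(concat_append, n))
```

Both functions return the same values; only the list version's traced program grows with the input length.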

Worked mini-example

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
val = 4.0

# BROKEN: Python list round-trip, unrolled element by element at trace time under jit
# result = list(x); result.append(val); return jnp.array(result)

# FIXED: fully JAX-native, jit-safe
out = jnp.concatenate([x, jnp.atleast_1d(val)])
# → Array([1., 2., 3., 4.], dtype=float32)

Common pitfalls

  • val may be a 0-D scalar. Use jnp.atleast_1d(val) to promote it to shape (1,) before concatenating; passing a bare scalar raises an error because jnp.concatenate requires all operands to have the same rank.
  • jnp.concatenate axis defaults to 0 โ€” correct for 1-D arrays; no need to specify it explicitly here.
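  • jnp.append(x, val) also works: with its default axis=None it flattens both inputs and calls jnp.concatenate, so under jit it lowers to the same concatenation. Calling jnp.concatenate directly keeps the axis and operand shapes explicit, which is why it is the idiomatic form here.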
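The pitfalls above can be checked directly. This is a minimal sketch; the exact exception type raised for the bare-scalar case is an assumption that may differ across JAX versions, so both TypeError and ValueError are caught.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
val = jnp.float32(4.0)  # a 0-D scalar

# A bare 0-D operand has the wrong rank for concatenation with a 1-D array.
try:
    jnp.concatenate([x, val])
    bare_scalar_failed = False
except (TypeError, ValueError):  # error type may vary across JAX versions
    bare_scalar_failed = True

# Promoting to shape (1,) fixes it; axis=0 is the default.
fixed = jnp.concatenate([x, jnp.atleast_1d(val)])

# jnp.append gives the same result for this 1-D case.
via_append = jnp.append(x, val)

print(bare_scalar_failed)                  # True
print(jnp.array_equal(fixed, via_append))  # True
```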

Problem

Implement append_via_concat(x, val) that appends scalar val to 1-D array x using jnp.concatenate โ€” the JAX-native, jit-safe way.

  • x: 1-D JAX array.
  • val: scalar.

Returns: 1-D array of length x.shape[0] + 1 with val appended at the end.
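One possible solution, as a minimal sketch: it wraps the concatenation from the worked example in jax.jit and checks the length contract stated above. This is not necessarily the site's reference solution.

```python
import jax
import jax.numpy as jnp

@jax.jit
def append_via_concat(x, val):
    # Promote the scalar to shape (1,) so both operands are rank-1.
    tail = jnp.atleast_1d(val)
    # Concatenate along axis 0 (the default); length is x.shape[0] + 1.
    return jnp.concatenate([x, tail])

x = jnp.array([1.0, 2.0, 3.0])
out = append_via_concat(x, 4.0)
print(out)  # [1. 2. 3. 4.]
```

Because the shape of the output depends only on the shape of x, a new val with the same x shape reuses the compiled function instead of retracing.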

Hints

jax tracer-leak fix-broken
