Fix: List Append → jnp.concatenate
Why this matters
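A common porting mistake from NumPy/Python is building an array by mutating a
Python list inside a jit-compiled function. Python list operations happen at
trace time, not at run time. Inside jax.jit, array values are abstract
tracers. As a result, list(x) unrolls the array into one scalar tracer per
element, so the traced program grows with the array's length. Any step that
needs a concrete value, such as float(val) or an if on a traced comparison,
raises a concretization error such as TracerBoolConversionError. Appending to
a list defined outside the function is worse: the append runs only during
tracing, so the list ends up holding leaked tracers and is not updated on
later cached calls.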
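The trace-time behaviour is easy to observe with a Python side effect. A minimal sketch, assuming a hypothetical function double and an external list trace_log (neither is part of the exercise): the append runs while JAX traces, not on every call. A second call with the same shape and dtype therefore leaves the list unchanged, and only a new input shape triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical: an ordinary Python list outside the jitted function

@jax.jit
def double(x):
    # Python side effect: executes only while JAX traces the function.
    trace_log.append("traced")
    return x * 2

double(jnp.arange(3.0))
double(jnp.arange(3.0))  # same shape and dtype: compiled version reused, no trace
print(len(trace_log))    # 1
double(jnp.arange(4.0))  # new shape: JAX traces again
print(len(trace_log))    # 2
```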
The BROKEN pattern:
```python
import jax
import jax.numpy as jnp

@jax.jit
def broken_append(x, val):
    result = list(x)  # unrolls x into one scalar tracer per element at trace time
    result.append(val)
    return jnp.array(result)
```
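One concrete cost of the list pattern shows up in the traced program. The sketch below uses hypothetical function names (list_append, concat_append) and jax.make_jaxpr to count the operations recorded for a length-10 input. The exact counts depend on the JAX version, so only the comparison is meaningful.

```python
import jax
import jax.numpy as jnp

def list_append(x, val):
    result = list(x)          # one slice per element, unrolled at trace time
    result.append(val)
    return jnp.array(result)  # stacks the scalars back into an array

def concat_append(x, val):
    # Single concatenate, independent of the length of x.
    return jnp.concatenate([x, jnp.atleast_1d(val)])

x = jnp.zeros(10)
n_list = len(jax.make_jaxpr(list_append)(x, 1.0).jaxpr.eqns)
n_concat = len(jax.make_jaxpr(concat_append)(x, 1.0).jaxpr.eqns)
print(n_list, n_concat)  # the list version records many more operations
```

Doubling the input length roughly doubles n_list, while n_concat stays the same.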
Worked mini-example
```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
val = 4.0

# BROKEN: Python list round-trip, unrolled per element at trace time
# result = list(x); result.append(val); return jnp.array(result)

# FIXED: fully JAX-native, jit-safe
out = jnp.concatenate([x, jnp.atleast_1d(val)])
# → Array([1., 2., 3., 4.], dtype=float32)
```
Common pitfalls
- val may be a 0-D scalar: use jnp.atleast_1d(val) to promote it to rank 1 before concatenating; passing a bare scalar raises an error because the operands' ranks differ.
- The axis argument of jnp.concatenate defaults to 0, which is correct for 1-D arrays, so there is no need to specify it here.
- jnp.append(x, val) also works (with its default axis=None it flattens both inputs and then concatenates them), but jnp.concatenate is the more explicit form and the one this exercise asks for.
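The pitfalls above can be checked directly. A small sketch reusing x and val from the worked example. Catching both TypeError and ValueError is a hedge, since the exact exception type raised for the rank mismatch may vary between JAX versions.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
val = 4.0

# Pitfall 1: a 0-D operand is rejected by jnp.concatenate.
try:
    jnp.concatenate([x, jnp.asarray(val)])
    scalar_rejected = False
except (TypeError, ValueError):
    scalar_rejected = True
print("bare scalar rejected:", scalar_rejected)  # True

# Promote to rank 1; axis=0 is the default, so it is omitted.
fixed = jnp.concatenate([x, jnp.atleast_1d(val)])

# Pitfall 3: jnp.append flattens its inputs, so for a 1-D x it gives the same result.
via_append = jnp.append(x, val)
print(bool(jnp.array_equal(fixed, via_append)))  # True
```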
Problem
Implement append_via_concat(x, val) that appends scalar val to 1-D
array x using jnp.concatenate, the JAX-native, jit-safe way.
- x: 1-D JAX array.
- val: scalar.
Returns: 1-D array of length x.shape[0] + 1 with val appended at the end.
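One possible solution sketch. It assumes the function may be jit-compiled directly, which the problem statement does not specify. Calling it twice with different values of val shows that no stale value is baked in: val is a traced argument, so the compiled program is reused and the new value flows through.

```python
import jax
import jax.numpy as jnp

@jax.jit
def append_via_concat(x, val):
    """Append scalar val to the end of 1-D array x (jit-safe sketch)."""
    tail = jnp.atleast_1d(val)         # promote the 0-D scalar to shape (1,)
    return jnp.concatenate([x, tail])  # axis=0 by default

x = jnp.array([1.0, 2.0, 3.0])
print(append_via_concat(x, 4.0))  # [1. 2. 3. 4.]
print(append_via_concat(x, 5.0))  # [1. 2. 3. 5.]
```

Casting val to x.dtype before concatenating is an optional design choice not required by the problem. Without the cast, an integer x combined with a float val is promoted to a float result.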