medium
primitives
BCOO from Dense
Why this matters
jax.experimental.sparse.BCOO (Batched Coordinate format) is JAX’s native
sparse array type. BCOO.fromdense(arr) converts a dense array to sparse,
storing only the nonzero indices and values — the rest are implicit zeros.
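A minimal sketch of what fromdense actually stores, using a hypothetical 2 x 3 input. The indices attribute holds one (row, column) pair per stored element, data holds the matching values, and todense() rebuilds the original array.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Hypothetical 2 x 3 input with three nonzero entries.
dense = jnp.array([[0.0, 3.0, 0.0],
                   [4.0, 0.0, 5.0]])
sp = BCOO.fromdense(dense)

coords = sp.indices      # shape (nse, 2): one (row, col) pair per stored element
values = sp.data         # shape (nse,): the value stored at each coordinate
restored = sp.todense()  # rebuilds the dense array, filling unstored positions with 0
```

Each stored entry costs its value plus two int32 indices (the default index dtype), so for a matrix this small the sparse form is larger than the dense one; the savings appear when most entries are zero.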
Used in:
- Graph neural networks: adjacency matrices are sparse.
- Sparse linear systems: finite-element methods and physics simulations.
- Sparse Jacobians: automatically computed sparse derivatives.
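For the graph-neural-network case, neighbor aggregation reduces to a sparse-dense matmul. A sketch, assuming a hypothetical three-node path graph and made-up node features: multiplying the BCOO adjacency matrix by a dense feature matrix with @ sums each node's neighbor features and returns an ordinary dense array.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Hypothetical undirected path graph 0 - 1 - 2 as a dense adjacency matrix.
adj_dense = jnp.array([[0.0, 1.0, 0.0],
                       [1.0, 0.0, 1.0],
                       [0.0, 1.0, 0.0]])
adj = BCOO.fromdense(adj_dense)  # stores only the 4 edge entries

# Made-up 2-dimensional feature vector per node.
features = jnp.array([[1.0, 0.0],
                      [0.0, 1.0],
                      [2.0, 2.0]])

# Sparse @ dense: row i is the sum of the features of node i's neighbors.
aggregated = adj @ features
```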
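Key property: .nse (number of stored elements) is the number of entries the BCOO array explicitly stores. For an array built with BCOO.fromdense and the default nse, that equals the count of nonzeros in the dense input; in general it can also include explicitly stored zeros or padding.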
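The difference between stored elements and nonzeros becomes visible when nse is passed explicitly to fromdense. In this sketch (the input values are illustrative), the padded array reports nse == 4 even though only two entries are nonzero; the extra slots hold zero values and leave todense() unchanged.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])

inferred = BCOO.fromdense(dense)       # nse computed from the data: 2
padded = BCOO.fromdense(dense, nse=4)  # nse fixed by the caller: 2 real entries + 2 padding slots

stored = padded.nse                                   # counts the padding slots too
nonzero_values = int(jnp.count_nonzero(padded.data))  # only the real entries are nonzero
```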
Worked mini-example
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
sp = BCOO.fromdense(dense)
print(sp.nse) # 2 (two nonzero entries)
Common pitfalls
- BCOO is “experimental”: not all JAX ops support sparse inputs; matmul (@) does, but many others do not.
- nse is an integer: cast it to float32 if the test contract expects a scalar float.
- All-zero arrays: nse is 0, which is valid.
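The pitfalls above can be checked together in a short sketch (the 3 x 4 shape is arbitrary). An all-zero input gives nse == 0, nse itself is a plain Python int, and jnp.float32 turns it into a float32 scalar. The matmul line shows that @ still works on an empty BCOO array; for an op without sparse support, densifying with todense() first is one simple, if memory-hungry, fallback.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

empty = BCOO.fromdense(jnp.zeros((3, 4)))  # no nonzeros, so nse is 0

count = empty.nse               # plain Python int, not a JAX array
count_f32 = jnp.float32(count)  # 0-d float32 array, as a scalar-float contract expects

# Sparse @ dense is supported; the result is a dense (3, 2) array of zeros.
product = empty @ jnp.ones((4, 2))

# Fallback for an op with no sparse rule: densify first (costs a full dense copy).
activated = jnp.tanh(empty.todense())
```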
Problem
Implement sparse_nnz(dense) that creates a BCOO sparse array from a
2-D dense array and returns the number of nonzero elements as a float32
scalar.
- dense: 2-D JAX array.
- Returns: scalar float32, the count of nonzero entries.
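One possible sketch of sparse_nnz, assuming the caller passes a concrete (non-traced) array. With the default nse, fromdense counts the nonzeros from the actual values, which requires them to be concrete, so this version is meant to be called eagerly rather than inside jax.jit. jnp.asarray with dtype=jnp.float32 produces the 0-d float32 result the contract asks for.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def sparse_nnz(dense):
    """Return the number of nonzero entries of a 2-D array as a float32 scalar."""
    # Default nse: fromdense stores exactly the nonzero entries.
    # Needs concrete values, so call this outside jax.jit.
    sp = BCOO.fromdense(dense)
    # nse is a Python int; convert it to a 0-d float32 array.
    return jnp.asarray(sp.nse, dtype=jnp.float32)
```

The result should agree with jnp.count_nonzero(dense) cast to float32, which makes a convenient cross-check and, unlike the BCOO route, also works under jit.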
Hints
jax
sparse
bcoo