medium primitives

BCOO from Dense

Why this matters

jax.experimental.sparse.BCOO (Batched Coordinate format) is JAX’s native sparse array type. BCOO.fromdense(arr) converts a dense array to sparse form. It stores only the nonzero values (in .data) and their coordinates (in .indices); every other entry is an implicit zero.
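The short sketch below shows what fromdense stores for a small 2×3 matrix. It inspects the .data and .indices attributes of the result and checks that .todense() round-trips back to the original array. The matrix values are arbitrary illustration data.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Arbitrary 2x3 example with three nonzero entries.
dense = jnp.array([[0.0, 3.0, 0.0],
                   [4.0, 0.0, 5.0]])
sp = BCOO.fromdense(dense)

# One stored value per nonzero, in row-major order.
values = sp.data.tolist()
# Each row of sp.indices is the (row, column) coordinate of the matching value.
coords = sp.indices.tolist()
# Converting back to dense reproduces the original array.
round_trip_ok = bool(jnp.array_equal(sp.todense(), dense))
```

Only the three nonzero values and their coordinates are kept. The zeros are implied by their absence from .indices.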

Used in:

  • Graph neural networks — adjacency matrices are sparse.
  • Sparse linear systems — finite-element, physics simulations.
  • Sparse Jacobians — derivatives produced by automatic differentiation that are mostly zeros.

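Key property: .nse (number of stored elements) is the number of entries the array explicitly stores. For an array built with BCOO.fromdense using default arguments, this equals the number of nonzeros. In general, stored entries can also include explicit zeros.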
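To show the difference between stored elements and nonzeros, this sketch builds a BCOO directly from a (data, indices) pair in which one stored value is an explicit zero. It then compares .nse with an actual nonzero count. The specific values are made up for illustration.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Two stored entries at (0, 0) and (1, 1); the second one is an explicit zero.
data = jnp.array([1.0, 0.0])
indices = jnp.array([[0, 0], [1, 1]])
sp = BCOO((data, indices), shape=(2, 2))

stored = sp.nse                                  # counts stored entries
nonzero = int(jnp.count_nonzero(sp.todense()))   # counts actual nonzero values

# fromdense keeps only true nonzeros, so here the two counts agree.
sp_clean = BCOO.fromdense(sp.todense())
```

Here stored is 2 while nonzero is 1. After a fromdense round trip the explicit zero is dropped and sp_clean.nse is 1.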

Worked mini-example

import jax.numpy as jnp
from jax.experimental.sparse import BCOO

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
sp = BCOO.fromdense(dense)
print(sp.nse)  # 2  (two nonzero entries)

Common pitfalls

  • BCOO is “experimental” — not all JAX ops support sparse inputs; matmul (@) does, but many others do not.
  • nse is a plain Python int, not a JAX array. Convert it, for example with jnp.asarray(sp.nse, dtype=jnp.float32), if the test contract expects a float32 scalar.
  • All-zero arrays: nse is 0, which is valid.
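The pitfalls above can be checked with a small sketch. It shows a BCOO used directly on the left of @, the int-to-float32 conversion of .nse, and fromdense on an all-zero array. The vector x and the array shapes are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
sp = BCOO.fromdense(dense)
x = jnp.array([3.0, 4.0])

# Matmul accepts a BCOO operand and returns a dense result.
y = sp @ x

# nse is a Python int; convert explicitly when a float32 scalar is required.
nse_f32 = jnp.asarray(sp.nse, dtype=jnp.float32)

# An all-zero input yields a valid BCOO with no stored elements.
empty = BCOO.fromdense(jnp.zeros((3, 3)))
```

For this input y is [3., 8.], nse_f32 is a 0-d float32 array holding 2.0, and empty.nse is 0.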

Problem

Implement sparse_nnz(dense) that creates a BCOO sparse array from a 2-D dense array and returns the number of nonzero elements as a float32 scalar.

  • dense: 2-D JAX array.
  • Returns: scalar float32 — count of nonzero entries.
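One possible sketch of sparse_nnz, assuming the input is a concrete (non-traced) array. It converts with BCOO.fromdense and returns .nse as a float32 scalar. It is an illustration, not the site's reference solution.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def sparse_nnz(dense):
    """Return the number of nonzero entries of a 2-D array as a float32 scalar."""
    # With default arguments, fromdense infers nse from the array's concrete values,
    # so nse equals the count of nonzero entries.
    sp = BCOO.fromdense(dense)
    # nse is a Python int; wrap it as a 0-d float32 JAX array for the return contract.
    return jnp.asarray(sp.nse, dtype=jnp.float32)
```

Because the default nse is read from concrete values, this sketch is meant to be called outside jax.jit. Inside jit, fromdense needs nse passed explicitly.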

Hints

jax sparse bcoo
