
BCOO Round-Trip

Why this matters

sp.todense() converts a BCOO sparse array back to a regular dense array. This round-trip (dense → BCOO → dense) is useful when you need to:

  • Debug: inspect intermediate sparse values as a dense array.
  • Interface with dense-only ops: some JAX functions don't accept BCOO, so convert back before calling them.
  • Verify correctness β€” confirm sparse construction preserved values.
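A minimal sketch of the first two uses, assuming the same 2×2 matrix as the worked mini-example below. The BCOO attributes data, indices and nse expose the stored values, their (row, col) coordinates and the number of stored elements. Converting back with .todense() then lets a routine that expects a dense array consume the result. jnp.linalg.det is used here purely as an illustrative dense-only consumer; it is not special to BCOO.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
sp = BCOO.fromdense(dense)

# Debug: inspect the sparse components directly...
print(sp.data)     # stored values, one per specified element
print(sp.indices)  # (row, col) coordinates of each stored value
print(sp.nse)      # number of specified (stored) elements

# ...or view the whole matrix as a dense array.
print(sp.todense())

# Interface with a dense-only routine: convert back to dense first.
det = jnp.linalg.det(sp.todense())
```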

The round-trip is lossless: nonzero values are stored exactly (no quantization or rounding), and zero entries are reconstructed as 0.0.
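One way to check the lossless claim is to compare with exact equality rather than a tolerance. The sketch below uses a non-square float32 matrix with a few arbitrary values (chosen only for illustration) and also confirms that shape and dtype survive the round-trip.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Non-square matrix, mostly zeros, with a few "awkward" float values.
dense = jnp.array(
    [
        [0.1, 0.0, 0.0, -3.5],
        [0.0, 0.0, 1e-7, 0.0],
        [0.0, 2.25, 0.0, 0.0],
    ],
    dtype=jnp.float32,
)

sp = BCOO.fromdense(dense)  # stores only the 4 nonzeros
out = sp.todense()          # zeros are filled back in

# Exact element-wise equality, no tolerance: stored values are copied, not rounded.
same_values = bool(jnp.array_equal(out, dense))
print(same_values, out.shape, out.dtype, sp.nse)
```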

Worked mini-example

import jax.numpy as jnp
from jax.experimental.sparse import BCOO

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
sp = BCOO.fromdense(dense)
out = sp.todense()
# out == [[1.0, 0.0], [0.0, 2.0]], identical to input

Common pitfalls

  • Memory cost: the dense output allocates O(m×n) memory regardless of sparsity, so avoid calling .todense() on huge matrices.
  • Hot loops: round-tripping inside a jitted loop defeats the purpose of using sparse; keep the sparse representation for as long as possible.
  • Shape preservation: .todense() restores the original shape exactly.
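A related detail when the round-trip does end up under jax.jit: BCOO.fromdense normally infers nse (the number of specified elements) by counting nonzeros, which cannot be done on a traced array. The sketch below assumes nse is counted eagerly outside jit and passed as a static argument; roundtrip_jit is a hypothetical helper name. Each distinct nse value triggers a recompile, and an nse smaller than the true nonzero count would truncate entries, so the round-trip is only lossless when nse covers every nonzero.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


@partial(jax.jit, static_argnames="nse")
def roundtrip_jit(dense, nse):
    # Hypothetical helper. Under jit the nonzero count is unknown at trace time,
    # so nse must be supplied as a static Python int.
    sp = BCOO.fromdense(dense, nse=nse)
    return sp.todense()


dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
nse = int((dense != 0).sum())  # counted eagerly, outside jit
out = roundtrip_jit(dense, nse=nse)
```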

Problem

Implement sparse_roundtrip(dense) that converts a dense array to BCOO and immediately back to dense, returning the reconstructed array.

  • dense: 2-D jax array.
  • Returns: 2-D array with the same shape and values as dense.
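One possible sketch of sparse_roundtrip, following the worked mini-example directly. It calls BCOO.fromdense without nse, so it is intended to run eagerly; putting it under jax.jit would require a static nse as discussed in the pitfalls.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def sparse_roundtrip(dense):
    """Convert a 2-D dense array to BCOO and immediately back to dense."""
    sp = BCOO.fromdense(dense)  # nse inferred eagerly from the nonzeros
    return sp.todense()         # stored values scattered back, zeros elsewhere


x = jnp.array([[1.0, 0.0, 3.0], [0.0, 0.0, 0.0]])  # includes an all-zero row
y = sparse_roundtrip(x)
```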

Hints

jax sparse bcoo todense
