hard
primitives
Sparse Matrix-Vector Product
Why this matters
Sparse matrix-vector multiplication (SpMV) is one of the most common
operations in scientific computing, graph processing, and deep learning.
For a matrix with nnz nonzeros, SpMV costs O(nnz) instead of O(m×n),
which can be orders of magnitude faster when the matrix is sparse.
JAX’s BCOO format supports sparse @ dense via the @ operator natively.
This is used in:
- Graph neural networks — aggregating messages over sparse adjacency.
- Sparse attention — masking most attention weights.
- Physics simulations — finite-element stiffness matrices.
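To make the O(nnz) cost concrete, here is a minimal sketch that performs the aggregation by hand on a small adjacency matrix and compares it with the `@` operator. The graph and the names `adj`, `node_values`, `manual`, and `builtin` are illustrative assumptions, not part of the original exercise; the sketch also assumes the default BCOO layout, where `indices` holds one (row, col) pair per stored element.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Adjacency matrix of a small directed graph (hypothetical example data).
adj = jnp.array([[0.0, 1.0, 1.0],
                 [1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0]])
node_values = jnp.array([10.0, 20.0, 30.0])

sp_adj = BCOO.fromdense(adj)

# COO view: one (row, col) index pair and one value per stored nonzero.
rows = sp_adj.indices[:, 0]
cols = sp_adj.indices[:, 1]
vals = sp_adj.data

# Manual SpMV: one multiply and one scatter-add per stored nonzero,
# so the work scales with nnz rather than with m * n.
manual = jnp.zeros(adj.shape[0]).at[rows].add(vals * node_values[cols])

# The same aggregation through the @ operator.
builtin = sp_adj @ node_values
```

Both `manual` and `builtin` should match the dense product `adj @ node_values`; each output entry is the sum of the values of that node's neighbours.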
Worked mini-example
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
v = jnp.array([1.0, 2.0])
sp = BCOO.fromdense(A)
result = sp @ v # [2.0, 6.0]
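It can help to look at what `BCOO.fromdense` actually stores. The sketch below reuses the matrix from the mini-example and inspects the sparse representation through the `nse`, `data`, and `indices` attributes and the `todense()` method; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
sp = BCOO.fromdense(A)

num_stored = sp.nse        # number of stored elements
values = sp.data           # the stored values, one per element
positions = sp.indices     # a (row, col) pair for each stored value
roundtrip = sp.todense()   # convert back to a regular dense array

print(num_stored, values.shape, positions.shape)
```

Only the two nonzero entries are stored, which is why the matrix-vector product only needs to touch two values instead of all four.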
Common pitfalls
- dense @ sparse: vec @ sp (row-vector times sparse) may not be optimized or supported; prefer sp @ vec.
- Result is dense: the output of sp @ vec is a regular dense array.
- Check op support before using sparse: not all JAX ops handle BCOO.
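The pitfalls above can be checked on a small non-square matrix. This is a sketch under stated assumptions: the matrix and vector values are made up, and falling back to `todense()` is only one possible workaround for an unsupported op, reasonable when the matrix is small enough to densify.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

A = jnp.array([[2.0, 0.0, 1.0],
               [0.0, 3.0, 0.0]])      # shape (m, n) = (2, 3)
sp = BCOO.fromdense(A)

col_vec = jnp.array([1.0, 2.0, 3.0])  # length n
row_vec = jnp.array([1.0, 2.0])       # length m

# Keep the sparse operand on the left: vec @ sp equals sp^T @ vec.
left_product = sp.transpose() @ row_vec

# The product is an ordinary dense array, not a BCOO.
out = sp @ col_vec
is_sparse = isinstance(out, BCOO)

# When an op lacks BCOO support, one fallback is to densify first.
row_sums = sp.todense().sum(axis=1)
```

Here `left_product` has shape `(n,)` and should agree with the dense expression `row_vec @ A`, while `is_sparse` is `False`.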
Problem
Implement sparse_matvec(dense_matrix, vec) that converts dense_matrix
to BCOO, then computes the matrix-vector product.
- dense_matrix: 2-D jax array (m, n).
- vec: 1-D jax array (n,).
- Returns: 1-D array (m,).
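One possible sketch of a function matching this specification is shown below. It is not the site's reference solution, just a direct reading of the problem statement; the test matrix `M` and vector `x` are made-up example data.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def sparse_matvec(dense_matrix, vec):
    """Convert dense_matrix (m, n) to BCOO, then multiply by vec (n,)."""
    sp = BCOO.fromdense(dense_matrix)  # keeps only the nonzero entries
    return sp @ vec                    # dense 1-D result of shape (m,)


# Example usage with hypothetical data.
M = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 0.0, 3.0]])
x = jnp.array([1.0, 1.0, 1.0])
y = sparse_matvec(M, x)
```

As written this runs eagerly. Under `jax.jit` the number of stored elements generally has to be known ahead of time, which `BCOO.fromdense` accepts through its `nse` argument.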
jax
sparse
bcoo
matmul