
Sparse Matrix-Vector Product

Why this matters

Sparse matrix-vector multiplication (SpMV) is one of the most common operations in scientific computing, graph processing, and deep learning. For a matrix with nnz nonzeros, SpMV costs O(nnz) instead of O(m×n), which can be orders of magnitude faster when nnz ≪ m×n.
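To see where the O(nnz) bound comes from, here is a minimal sketch of SpMV over coordinate (COO) triples, written by hand rather than with BCOO. The helper name coo_matvec and its arguments are hypothetical; the sketch assumes jax.ops.segment_sum to add up the per-row products. Each stored nonzero contributes exactly one multiply and one add, so the work scales with nnz, not with m×n.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

def coo_matvec(rows, cols, vals, v, m):
    # Hypothetical helper: one multiply per stored nonzero, O(nnz) work.
    products = vals * v[cols]
    # Accumulate products that share a row index into y[row]; empty rows stay 0.
    return segment_sum(products, rows, num_segments=m)

# 3x3 matrix with 4 nonzeros:
# [[1, 0, 2],
#  [0, 0, 0],
#  [0, 3, 4]]
rows = jnp.array([0, 0, 2, 2])
cols = jnp.array([0, 2, 1, 2])
vals = jnp.array([1.0, 2.0, 3.0, 4.0])
v = jnp.array([1.0, 1.0, 2.0])

y = coo_matvec(rows, cols, vals, v, m=3)  # [5., 0., 11.]
```

Row 0 gives 1·1 + 2·2 = 5, row 1 has no nonzeros, and row 2 gives 3·1 + 4·2 = 11.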

JAX’s BCOO format (from jax.experimental.sparse) supports sparse @ dense products natively through the @ operator. This is used in:

  • Graph neural networks — aggregating messages over sparse adjacency.
  • Sparse attention — masking most attention weights.
  • Physics simulations — finite-element stiffness matrices.
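As an illustration of the graph case, the sketch below sums neighbour features over a small sparse adjacency matrix. The 3-node path graph and the feature values are made up for the example; it assumes that a BCOO matrix times a dense 2-D array works through @ in the same way as the matrix-vector case.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Hypothetical 3-node path graph 0 - 1 - 2, stored as a symmetric adjacency matrix.
adj_dense = jnp.array([[0.0, 1.0, 0.0],
                       [1.0, 0.0, 1.0],
                       [0.0, 1.0, 0.0]])
adj = BCOO.fromdense(adj_dense)

# One feature vector per node: shape (num_nodes, num_features).
feats = jnp.array([[1.0, 10.0],
                   [2.0, 20.0],
                   [3.0, 30.0]])

# Sum aggregation: row i of the result is the sum of node i's neighbours' features.
messages = adj @ feats  # [[2, 20], [4, 40], [2, 20]]
```

Real GNN layers usually normalise the adjacency or apply learned weights afterwards; this only shows the sparse aggregation step.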

Worked mini-example

import jax.numpy as jnp
from jax.experimental.sparse import BCOO

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
v = jnp.array([1.0, 2.0])
sp = BCOO.fromdense(A)
result = sp @ v  # [2.0, 6.0]
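To see what BCOO actually stores for the diagonal matrix above, you can inspect its data, indices, and nse (number of stored elements) attributes. This is a small sketch for inspection only; only the two nonzeros are kept, together with their (row, col) coordinates.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
sp = BCOO.fromdense(A)

print(sp.nse)      # number of stored elements (the two diagonal entries)
print(sp.data)     # stored values, in row-major order
print(sp.indices)  # one (row, col) pair per stored value
print(sp.todense())  # round-trips back to A
```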

Common pitfalls

  • Operand order: vec @ sp (row vector times sparse matrix) may not be optimized or supported; prefer sp @ vec.
  • Result is dense — the output of sp @ vec is a regular dense array.
  • Check op support before using sparse — not all JAX ops handle BCOO.
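The first two pitfalls can be checked directly. The sketch below confirms that sp @ vec comes back as an ordinary dense jax.Array, and shows one way (an assumption on our part, not the only option) to get a row-vector product while keeping the sparse operand on the left: transpose the BCOO matrix with .T, since Aᵀu equals uᵀA. No performance claim is made here; it only keeps to the sp @ vec form recommended above.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

A = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 3.0, 0.0]])
sp = BCOO.fromdense(A)

# sp @ vec returns a regular dense array, not a BCOO object.
col = sp @ jnp.array([1.0, 1.0, 1.0])  # [3., 3.]

# Row-vector product u^T A, computed as (A^T) u with the sparse matrix on the left.
u = jnp.array([1.0, 2.0])
row = sp.T @ u  # [1., 6., 2.]
```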

Problem

Implement sparse_matvec(dense_matrix, vec), which converts dense_matrix to BCOO and then computes the matrix-vector product.

  • dense_matrix: 2-D jax array (m, n).
  • vec: 1-D jax array (n,).
  • Returns: 1-D array (m,).
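One possible shape for the function, following the spec above, is sketched below. It is a minimal sketch, not the reference solution: it simply chains BCOO.fromdense with the @ operator from the worked example.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

def sparse_matvec(dense_matrix, vec):
    # Convert the (m, n) dense matrix to BCOO; only nonzero entries are stored.
    sp = BCOO.fromdense(dense_matrix)
    # Sparse (m, n) @ dense (n,) gives a dense (m,) array.
    return sp @ vec

M = jnp.array([[0.0, 4.0, 0.0],
               [1.0, 0.0, 5.0]])
x = jnp.array([1.0, 2.0, 3.0])
out = sparse_matvec(M, x)  # [8., 16.]
```

If you wrap this in jax.jit, BCOO.fromdense generally needs an explicit nse=... argument, because the number of nonzeros cannot be read from a traced array.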

Hints

jax sparse bcoo matmul
