medium primitives

slogdet for Stable log|det|

Why this matters

Computing the log-determinant is a core operation in probabilistic ML: Gaussian log-likelihoods, normalising flows, and Bayesian inference all require log|det(·)| of a matrix (a covariance Σ, or the Jacobian of a flow) whose determinant can be astronomically large or vanishingly small.

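jnp.linalg.det(A) returns the raw determinant. Because the determinant is a product of n eigenvalues, it overflows to inf when many of them are large in magnitude and underflows to zero when many are small, which happens routinely for large n. jnp.linalg.slogdet(A) sidesteps this by returning the pair (sign, log|det|), accumulated as a sum of logarithms rather than a product, so you stay in log-space throughout: no overflow, no underflow.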
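A quick way to see the difference is to compare the two functions on scaled identity matrices, whose determinants are known in closed form. This sketch assumes JAX's default float32 precision; the size n = 400 and the names big and small are illustrative choices. The same failure occurs in float64, since 10^400 and 10^-400 are outside its range too.

```python
import jax.numpy as jnp

# Illustrative sizes: det(10 * I_n) = 10**n and det(0.1 * I_n) = 10**-n.
# For n = 400 both are far outside the float32 (and float64) range.
n = 400
big = 10.0 * jnp.eye(n)
small = 0.1 * jnp.eye(n)

# Raw determinants: the first overflows, the second underflows.
det_big = jnp.linalg.det(big)      # inf
det_small = jnp.linalg.det(small)  # 0.0

# Taking the log afterwards cannot recover the lost information.
naive_big = jnp.log(jnp.abs(det_big))      # inf instead of ~921.03
naive_small = jnp.log(jnp.abs(det_small))  # -inf instead of ~-921.03

# slogdet stays in log-space: logabsdet is +/- n * log(10).
sign_big, logdet_big = jnp.linalg.slogdet(big)        # 1.0, ~921.03
sign_small, logdet_small = jnp.linalg.slogdet(small)  # 1.0, ~-921.03
```

The log-space results are ordinary, well-scaled numbers that can be added to other log-probabilities without any special handling.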

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[2.0, 0.0],
               [0.0, 3.0]])

sign, logabsdet = jnp.linalg.slogdet(A)
# sign       → 1.0   (det = 6 > 0)
# logabsdet  → 1.7918…  (log(6))

# Equivalent but numerically unsafe:
# jnp.log(jnp.abs(jnp.linalg.det(A)))  ← overflows or underflows for large matrices
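The example above has a positive determinant. As an illustrative addition (the matrix B and the names below are not from the original), here is a case with det < 0. It shows that slogdet reports the sign separately, and that the determinant can be rebuilt as sign · exp(logabsdet) when its magnitude is known to be in range:

```python
import jax.numpy as jnp

# det(B) = 1*4 - 2*3 = -2
B = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

sign_B, logabsdet_B = jnp.linalg.slogdet(B)
# sign_B is -1.0, logabsdet_B is log(2) ≈ 0.6931

# Rebuilding det is only safe when |det| fits in the float range,
# which is trivially true for this 2x2 matrix.
det_B = sign_B * jnp.exp(logabsdet_B)  # ≈ -2.0
```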

Common pitfalls

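  • det overflows or underflows for large n, so always prefer slogdet for log-likelihoods.
  • The sign is returned separately. In general det(A) = sign · exp(logabsdet); for a positive-definite covariance, sign is always 1. Don't try to fold the sign in with jnp.log(sign): for sign = -1 that is nan in real arithmetic.
  • An exactly singular A returns sign = 0 and logabsdet = -inf, so check sign before using the result. A nearly singular A may instead return a finite but very negative logabsdet.
  • log(abs(det(A))) is not a safe substitute: det(A) has already overflowed or underflowed before the log is applied.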
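These pitfalls come together in the Gaussian log-likelihood mentioned at the start. Below is a minimal sketch, assuming 1-D inputs x and mean of length d and a (d, d) covariance; the function name gaussian_logpdf is hypothetical, not a JAX API. It uses slogdet for the log|det(Σ)| term and returns nan whenever sign is not positive. The result is checked against jax.scipy.stats.multivariate_normal.logpdf, and a rank-deficient matrix S shows the sign = 0, logabsdet = -inf case.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def gaussian_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) at x (hypothetical helper, sketch only).

    Assumes x and mean are 1-D of length d and cov is (d, d).
    Returns nan when det(cov) <= 0. Note that a positive sign alone
    does not prove cov is positive definite.
    """
    d = x.shape[-1]
    sign, logdet = jnp.linalg.slogdet(cov)
    diff = x - mean
    # Mahalanobis term (x - mean)^T cov^{-1} (x - mean), via a linear
    # solve rather than an explicit inverse.
    maha = diff @ jnp.linalg.solve(cov, diff)
    logp = -0.5 * (d * jnp.log(2.0 * jnp.pi) + logdet + maha)
    # A valid covariance has det > 0; flag everything else with nan.
    return jnp.where(sign > 0, logp, jnp.nan)


cov = jnp.array([[2.0, 0.3],
                 [0.3, 1.0]])
mean = jnp.zeros(2)
x = jnp.array([0.5, -1.0])

ours = gaussian_logpdf(x, mean, cov)
reference = multivariate_normal.logpdf(x, mean, cov)

# An exactly singular matrix: the second row is twice the first.
S = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])
sign_S, logdet_S = jnp.linalg.slogdet(S)  # sign_S is 0.0, logdet_S is -inf
```

For a covariance known to be symmetric positive definite, a common alternative is the Cholesky factor L, since log|det(Σ)| = 2 * jnp.sum(jnp.log(jnp.diag(L))); slogdet is the general-purpose choice that also works for non-symmetric matrices.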

Problem

Implement log_abs_det(A) that returns log(|det(A)|) using jnp.linalg.slogdet.

  • A: 2-D JAX array of shape (n, n), square and non-singular.
  • Returns: scalar, the log of the absolute determinant.

Hints

jax linalg det slogdet
