hard primitives

2-D FFT DC Component

Why this matters

jnp.fft.fft2(x) computes the 2-D Discrete Fourier Transform, the natural extension of the 1-D FFT to matrices and images. By default it acts on the last two axes. The transform is separable: it applies 1-D FFTs along each axis in turn (for example rows, then columns; the order does not change the result), decomposing the input into 2-D spatial frequencies.
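Separability can be checked numerically. The sketch below is an illustrative example, not part of the original problem: it compares jnp.fft.fft2 against two nested 1-D jnp.fft.fft passes on a small non-square matrix. The tolerance atol=1e-4 is an assumption chosen to absorb float32 rounding.

```python
import jax.numpy as jnp

# A small non-square real matrix (3 rows, 4 columns).
x = jnp.arange(12.0).reshape(3, 4)

# Full 2-D transform (default axes are the last two).
full = jnp.fft.fft2(x)

# Same result from two passes of the 1-D FFT:
# first along each row (axis=1), then along each column (axis=0).
rows_then_cols = jnp.fft.fft(jnp.fft.fft(x, axis=1), axis=0)

# Swapping the order of the two passes gives the same spectrum.
cols_then_rows = jnp.fft.fft(jnp.fft.fft(x, axis=0), axis=1)

print(jnp.allclose(full, rows_then_cols, atol=1e-4))
print(jnp.allclose(full, cols_then_rows, atol=1e-4))
```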

Applications:

  • Image processing: low-pass filtering (blur), high-pass filtering (edge detection), frequency-domain convolution.
  • Crystallography and physics โ€” diffraction patterns, structure factors.
  • Deep learning โ€” spectral pooling, frequency-domain attention.
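The image-processing use case can be made concrete with a minimal low-pass filter sketch. Assumptions: the helper name lowpass_2d and its cutoff parameter are hypothetical, and the mask is a hard separable box in frequency space (an "ideal" filter), which is simple but introduces ringing; practical code often uses a smoother mask. jnp.fft.fftfreq supplies the frequency of each FFT bin in cycles per sample.

```python
import jax.numpy as jnp

def lowpass_2d(img, cutoff):
    """Hypothetical helper: keep frequencies with |f_row| <= cutoff and
    |f_col| <= cutoff (cycles/sample), zero the rest, transform back."""
    m, n = img.shape
    fy = jnp.fft.fftfreq(m)  # row-axis frequencies in cycles per sample
    fx = jnp.fft.fftfreq(n)  # column-axis frequencies
    # Boolean mask on the unshifted FFT grid; the DC bin [0, 0] is always kept.
    mask = (jnp.abs(fy)[:, None] <= cutoff) & (jnp.abs(fx)[None, :] <= cutoff)
    spectrum = jnp.fft.fft2(img)
    # The mask is symmetric under f -> -f, so for real img the inverse
    # is real up to rounding; drop the tiny imaginary part.
    return jnp.real(jnp.fft.ifft2(spectrum * mask))

img = jnp.arange(64.0).reshape(8, 8) % 5
blurred = lowpass_2d(img, cutoff=0.125)

# Keeping the DC bin preserves the sum (and therefore the mean) of the image.
print(jnp.allclose(blurred.mean(), img.mean(), atol=1e-4))
```

Because the filter never touches fft2(img)[0, 0], the blurred image keeps the same total intensity, which ties back to the DC identity discussed next.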

The DC component fft2(x)[0, 0] is the zero-frequency bin in both dimensions; it equals sum(x), the sum of all elements. This is the 2-D analogue of the 1-D DC bin, fft(x)[0], which equals sum(x).
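The identity follows directly from the definition of the 2-D DFT. Assuming the default, unnormalized forward convention of jnp.fft.fft2 (norm="backward"), for an M by N input x:

```latex
X[k, l] = \sum_{m=0}^{M-1} \sum_{n=0}^{N-1} x[m, n]\,
          e^{-2\pi i \left( \frac{k m}{M} + \frac{l n}{N} \right)}

% At k = l = 0 every exponential is e^{0} = 1, so
X[0, 0] = \sum_{m=0}^{M-1} \sum_{n=0}^{N-1} x[m, n] = \operatorname{sum}(x)
```

With norm="ortho" the same bin is scaled by 1/\sqrt{MN}, so the plain-sum check holds only under the default normalization.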

Worked mini-example

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

dc = jnp.abs(jnp.fft.fft2(x)[0, 0])
# dc = 10.0   (= 1 + 2 + 3 + 4)

Common pitfalls

  • Output is complex: fft2 returns complex numbers; take jnp.abs for the magnitude.
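  • Indexing: the DC component is at [0, 0], not [-1, -1], and not at the center unless you have applied jnp.fft.fftshift, which moves it to [m//2, n//2].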
  • For real images use rfft2: the output shape is (m, n//2+1), roughly half the storage; use irfft2 for the inverse, passing s=x.shape when n is odd (otherwise the last axis comes back with length n-1).
  • DC equals sum(x): a useful sanity check is that fft2(x)[0, 0] must equal x.sum() up to floating-point precision.
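The pitfalls above can be exercised on one small array. This is an illustrative sketch, not from the original problem; the odd column count (n = 5) is an assumption chosen so that the irfft2 shape issue actually shows up.

```python
import jax.numpy as jnp

x = jnp.arange(20.0).reshape(4, 5)    # m = 4, n = 5 (odd last axis)
m, n = x.shape

# Pitfall 1: fft2 output is complex; read the magnitude with jnp.abs.
F = jnp.fft.fft2(x)
dc = F[0, 0]                          # complex scalar
dc_mag = jnp.abs(dc)                  # real magnitude

# Pitfall 2: fftshift rolls each axis by size//2, moving DC to [m//2, n//2].
F_shifted = jnp.fft.fftshift(F)
assert jnp.allclose(F_shifted[m // 2, n // 2], dc)

# Pitfall 3: rfft2 keeps only the non-negative frequencies of the last axis.
R = jnp.fft.rfft2(x)
print(R.shape)                        # (4, 3) == (m, n // 2 + 1)
# Without s, irfft2 would return a last axis of length 2 * (3 - 1) = 4.
x_back = jnp.fft.irfft2(R, s=x.shape)

# Pitfall 4: DC equals the plain sum, up to float32 rounding.
print(jnp.allclose(dc_mag, x.sum()), jnp.allclose(x_back, x, atol=1e-4))
```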

Problem

Implement fft_2d_dc(x) that returns the magnitude of the DC component of the 2-D FFT.

  • x: 2-D jax array.
  • Returns: a scalar, |fft2(x)[0, 0]|, which equals |sum(x)|.
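One possible sketch of fft_2d_dc, assuming x is a real or complex 2-D jax array. Wrapping it in jax.jit is an optional choice, not part of the problem statement.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fft_2d_dc(x):
    """Return |fft2(x)[0, 0]| for a 2-D array x (a sketch, not the reference solution)."""
    # DC bin of the unnormalized 2-D DFT; the value is complex even for real x.
    dc = jnp.fft.fft2(x)[0, 0]
    # Magnitude as a real scalar.
    return jnp.abs(dc)

x = jnp.array([[-1.0, -2.0, -3.0],
               [0.0, 0.0, 1.0]])
print(fft_2d_dc(x))  # expected |sum(x)| = |-5| = 5
```

Computing the full transform only to read one bin costs O(MN log(MN)); jnp.abs(x.sum()) returns the same value in O(MN), which makes it a convenient cross-check when testing.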

Hints

jnp.fft.fft2 (in jax.numpy.fft)
