medium primitives

1-D FFT Magnitude

Why this matters

jnp.fft.fft(x) computes the 1-D Discrete Fourier Transform (DFT), returning complex Fourier coefficients that decompose a signal into its constituent frequencies. Taking jnp.abs(...) yields the magnitude spectrum: the amplitude of each frequency bin.
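For reference, the DFT of a length-N signal is X[k] = sum over n of x[n] * exp(-2πi·k·n/N), for k = 0..N-1. The sketch below evaluates that formula directly as an N×N matrix-vector product and checks it against jnp.fft.fft. naive_dft is a hypothetical helper written for illustration only (it is not part of jax.numpy), and it costs O(N²), which is exactly the work the FFT avoids.

```python
import jax.numpy as jnp

def naive_dft(x):
    # Hypothetical illustrative helper: direct O(N^2) DFT,
    # X[k] = sum_n x[n] * exp(-2j*pi*k*n/N)
    n = x.shape[0]
    idx = jnp.arange(n)
    # (N, N) matrix of twiddle factors: row k holds exp(-2j*pi*k*m/N) for m = 0..N-1
    w = jnp.exp(-2j * jnp.pi * jnp.outer(idx, idx) / n)
    return w @ x

x = jnp.array([0.5, -1.0, 2.0, 0.0, 3.0, 1.5])
X_naive = naive_dft(x)
X_fft = jnp.fft.fft(x)      # same coefficients, computed in O(N log N)
mags = jnp.abs(X_fft)       # magnitude spectrum, same length as x
print(jnp.allclose(X_naive, X_fft, atol=1e-4))  # agreement up to float32 rounding
```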

FFTs are central to:

  • Signal processing: identify dominant frequencies in audio or sensor data.
  • Spectral analysis: power spectrum, frequency filtering.
  • Convolutions: by the convolution theorem, convolving in the frequency domain costs O(n log n) instead of O(n²) for direct convolution.
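A minimal sketch of FFT-based convolution, assuming you want the ordinary "full" linear convolution that jnp.convolve(a, b) returns. fft_convolve is a hypothetical helper for illustration. Both inputs are zero-padded to the full output length (via the n argument of jnp.fft.fft) so that the circular wrap-around inherent in the DFT does not corrupt the result.

```python
import jax.numpy as jnp

def fft_convolve(a, b):
    # Hypothetical illustrative helper: linear convolution via the convolution theorem.
    # Pad to len(a) + len(b) - 1 so circular convolution equals linear convolution.
    n = a.shape[0] + b.shape[0] - 1
    A = jnp.fft.fft(a, n=n)
    B = jnp.fft.fft(b, n=n)
    # Pointwise product in the frequency domain == convolution in the time domain
    return jnp.real(jnp.fft.ifft(A * B))

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([0.0, 1.0, 0.5])
print(fft_convolve(a, b))   # should match jnp.convolve(a, b) up to rounding
```

For inputs this short, direct convolution is just as fast; the O(n log n) advantage only shows up for long signals or kernels.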

Structural facts worth memorising:

  • Bin 0 (DC) = sum(x).
  • For real x, the output has Hermitian symmetry: X[n-k] = conj(X[k]), so the magnitudes of bins above n//2 mirror those below. Use rfft to compute only the n//2 + 1 non-redundant bins (next problem).
  • Output length equals input length.
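The three facts above can be checked numerically. This is a small sketch on an arbitrary real signal of odd length (the input values are made up for illustration); each flag should come out True and rfft should return n//2 + 1 bins.

```python
import jax.numpy as jnp

x = jnp.array([3.0, 1.0, -2.0, 4.0, 0.5, 2.0, -1.0])  # arbitrary real input, n = 7
n = x.shape[0]
X = jnp.fft.fft(x)

length_ok = X.shape == (n,)                     # output length equals input length
dc_ok = jnp.allclose(X[0], jnp.sum(x))          # bin 0 (DC) equals sum(x), here 7.5
k = jnp.arange(1, n)
hermitian_ok = jnp.allclose(X[n - k], jnp.conj(X[k]), atol=1e-5)  # X[n-k] == conj(X[k])
rfft_len = jnp.fft.rfft(x).shape[0]             # n//2 + 1 non-redundant bins
print(length_ok, dc_ok, hermitian_ok, rfft_len)
```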

Worked mini-example

import jax.numpy as jnp

x = jnp.array([1.0, 0.0, 0.0, 0.0])   # impulse at t=0
mags = jnp.abs(jnp.fft.fft(x))
# mags = [1.0, 1.0, 1.0, 1.0]          # flat spectrum: all frequencies have equal magnitude
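The opposite case is a pure tone. In this sketch (the frequency and length are chosen arbitrarily), a cosine that completes exactly 2 cycles in N = 8 samples puts all of its energy into bin 2 and its Hermitian mirror bin N - 2 = 6, each with magnitude N/2 = 4; every other bin is zero up to float32 rounding.

```python
import jax.numpy as jnp

N = 8
t = jnp.arange(N)
x = jnp.cos(2 * jnp.pi * 2 * t / N)   # cosine with exactly 2 cycles in N samples
mags = jnp.abs(jnp.fft.fft(x))
# Expect peaks of N/2 = 4 at bins 2 and 6; remaining bins are ~0
print(jnp.round(mags, 4))
```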

Common pitfalls

  • Complex output โ€” fft returns complex numbers; you must call jnp.abs(...) to get magnitudes.
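  • DC bin: mags[0] is the zero-frequency component, |sum(x)| = N*|mean(x)| (N times the average, not the average itself). A constant signal c has mags[0] = N*|c| and all other bins 0, up to floating-point rounding.
  • Use rfft for real signals: fft on real input also computes the redundant conjugate bins; rfft returns only the n//2 + 1 non-redundant bins and does roughly half the work.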
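A short sketch of the pitfalls above, using made-up inputs: fft output is complex until you take jnp.abs, a constant signal concentrates everything in the DC bin, and the rfft magnitudes coincide with the first n//2 + 1 fft magnitudes.

```python
import jax.numpy as jnp

# Complex output: fft returns complex coefficients, abs gives real magnitudes
x = jnp.array([1.0, -0.5, 2.0, 3.0, 0.0, -1.0, 0.25, 4.0])   # real, n = 8
coeffs = jnp.fft.fft(x)
full = jnp.abs(coeffs)

# Constant signal: mags[0] == N * |c| == 6 * 2.5 == 15.0, other bins ~0
c = jnp.full(6, -2.5)
mags_c = jnp.abs(jnp.fft.fft(c))

# rfft keeps only bins 0..n//2 of the full fft
half = jnp.abs(jnp.fft.rfft(x))
print(jnp.iscomplexobj(coeffs), half.shape)
print(jnp.allclose(half, full[: x.shape[0] // 2 + 1], atol=1e-5))
```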

Problem

Implement fft_magnitude(x) that returns the magnitude of each FFT bin.

  • x: 1-D jax array.
  • Returns: 1-D real-valued array of the same shape: the magnitudes of the complex FFT coefficients.
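One possible shape for an answer, offered as a sketch rather than the site's official solution: compute the full complex FFT, then take the element-wise absolute value. The jax.jit decorator is an optional choice here; the function works the same without it.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fft_magnitude(x):
    # Complex DFT coefficients, same length as x
    coeffs = jnp.fft.fft(x)
    # |X[k]| = sqrt(re^2 + im^2): one real magnitude per bin
    return jnp.abs(coeffs)

x = jnp.array([1.0, 0.0, 0.0, 0.0])
print(fft_magnitude(x))   # impulse at t=0 gives a flat spectrum of ones
```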

Hints

jax fft
