
DDPM Forward Noising

Implement the closed-form forward process of Denoising Diffusion Probabilistic Models (DDPM).

Background

Diffusion models corrupt a clean input x_0 over T timesteps by adding a small amount of Gaussian noise at each step. Computed naively, x_t requires iterating through all t steps. Because every step is Gaussian, the steps compose into a single Gaussian, and DDPM (Ho et al. 2020) uses this to sample x_t directly from x_0 for any timestep t using alpha_bar[t], the cumulative product of the per-step signal-preservation factors 1 - beta_s:

q(x_t | x_0) = N(x_t; sqrt(alpha_bar[t]) * x_0, (1 - alpha_bar[t]) * I)
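
One way to convince yourself of the closed form is to track the mean coefficient and noise variance of x_t one step at a time and compare them with sqrt(alpha_bar[t]) and 1 - alpha_bar[t]. The sketch below assumes the per-step update from Ho et al., x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps; the beta values and function names are illustrative assumptions, not part of the problem.

```python
import math

# Illustrative noise schedule (an assumption; any values in (0, 1) work).
betas = [0.01, 0.02, 0.05, 0.1]

def stepwise_moments(betas):
    """Track x_t = m * x_0 + (noise with variance v), one step at a time."""
    m, v = 1.0, 0.0
    out = []
    for beta in betas:
        alpha = 1.0 - beta
        m = math.sqrt(alpha) * m   # signal is scaled by sqrt(alpha_t)
        v = alpha * v + beta       # old noise shrinks, fresh noise adds beta_t
        out.append((m, v))
    return out

def closed_form_moments(betas):
    """Read the same moments directly off the running product alpha_bar."""
    alpha_bar = 1.0
    out = []
    for beta in betas:
        alpha_bar *= 1.0 - beta
        out.append((math.sqrt(alpha_bar), 1.0 - alpha_bar))
    return out
```

Both functions return the same (mean coefficient, variance) pairs up to floating-point error, which is why a single draw from q(x_t | x_0) can stand in for the t-step loop.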

Algorithm

a     = sqrt(alpha_bar[t])          # signal scaling
b     = sqrt(1 - alpha_bar[t])      # noise scaling
noise = sample from N(0, I)         # same shape as x_0, drawn with the given seed
x_t   = a * x_0 + b * noise
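
The four steps above can be sketched in NumPy as follows. This is a minimal illustration, not the reference solution: it assumes t indexes alpha_bar directly, the function name forward_noise is hypothetical, and NumPy's generator differs from PyTorch's, so the sampled values will not match the expected outputs for this problem.

```python
import numpy as np

def forward_noise(x0, t, alpha_bar, seed):
    """Sample x_t ~ q(x_t | x_0) in one shot (NumPy sketch)."""
    x0 = np.asarray(x0, dtype=np.float64)
    a = np.sqrt(alpha_bar[t])               # signal scaling
    b = np.sqrt(1.0 - alpha_bar[t])         # noise scaling
    rng = np.random.default_rng(seed)       # seeded generator for reproducibility
    noise = rng.standard_normal(x0.shape)   # N(0, I), same shape as x0
    return a * x0 + b * noise

# Usage with a toy alpha_bar (values are illustrative only).
alpha_bar = np.array([0.99, 0.9, 0.5, 0.1])
x0 = np.ones((4, 3))
x_t = forward_noise(x0, t=2, alpha_bar=alpha_bar, seed=0)
print(x_t.shape)  # (4, 3)
```

Creating the generator inside the function keeps the result a pure function of (x0, t, alpha_bar, seed), so two calls with the same arguments return the same x_t.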

Why It Works

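alpha_bar[t] = (1 - beta_1) * ... * (1 - beta_t) is the fraction of the clean signal's variance that survives after t steps; the remaining 1 - alpha_bar[t] is noise variance, so the two scaling factors always satisfy a^2 + b^2 = 1. When alpha_bar[t] ≈ 1 (early steps) the output stays close to x_0; when alpha_bar[t] ≈ 0 (late steps) the output is almost pure noise.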
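
To make the two regimes concrete, the sketch below builds alpha_bar from the linear beta schedule used in Ho et al. (1e-4 to 0.02 over 1000 steps) and prints the signal and noise weights at a few timesteps. The schedule is an assumption for illustration; in this problem alpha_bar is passed in precomputed.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # linear schedule from Ho et al. (assumed here)
alpha_bar = np.cumprod(1.0 - betas)    # running product of (1 - beta)

for t in (0, T // 2, T - 1):
    a = np.sqrt(alpha_bar[t])          # weight on x_0
    b = np.sqrt(1.0 - alpha_bar[t])    # weight on the noise
    print(f"t={t:4d}  signal={a:.4f}  noise={b:.4f}")
```

The signal weight starts near 1 and decays toward 0 while the noise weight does the opposite.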

Reference

Ho et al., “Denoising Diffusion Probabilistic Models” (2020).

PRNG note

PyTorch and JAX use different pseudo-random number generators. Given the same seed, they will produce different samples. The expected outputs for this problem are generated using PyTorch only. Your JAX solution will be tested for correctness of the algorithm (right distribution, right shape), not for exact value matching against PyTorch’s samples.
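
The problem does not say how the distribution check is carried out. One plausible approach, sketched here in NumPy under that caveat, is to feed a constant x0 through the function and compare the empirical mean and variance of x_t with the closed-form values. The names reference_sample and looks_correct, the sample size, and the tolerance are all assumptions.

```python
import numpy as np

def reference_sample(x0, t, alpha_bar, seed):
    """NumPy stand-in for the function under test (hypothetical name)."""
    noise = np.random.default_rng(seed).standard_normal(np.shape(x0))
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

def looks_correct(sample_fn, alpha_bar, t, n=200_000, tol=0.02):
    """Check shape, mean, and variance of x_t against q(x_t | x_0)."""
    x0 = np.full((n, 1), 2.0)                  # constant input: target mean is a * 2
    x_t = np.asarray(sample_fn(x0, t, alpha_bar, 0))
    mean_ok = abs(x_t.mean() - 2.0 * np.sqrt(alpha_bar[t])) < tol
    var_ok = abs(x_t.var() - (1.0 - alpha_bar[t])) < tol
    return bool(x_t.shape == x0.shape and mean_ok and var_ok)

alpha_bar = np.array([0.9, 0.5, 0.1])
print(looks_correct(reference_sample, alpha_bar, t=1))
```

A check like this is independent of the PRNG, so it accepts any implementation that draws from the right distribution and rejects common mistakes such as scaling the noise by 1 - alpha_bar[t] instead of its square root.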

Inputs / Output

  • x0: tensor of shape (N, d) — clean input.
  • t: int — timestep index.
  • alpha_bar: tensor of shape (T,) — precomputed cumulative alphas (from ddpm-noise-schedule).
  • seed: int — random seed for reproducibility.

Output: x_t of shape (N, d).

