
DDPM Denoising Step

Implement one step of the DDPM reverse (denoising) process: given a noisy sample x_t and the model’s noise prediction, produce a sample of x_{t-1}.

Background

The forward DDPM process gradually corrupts x_0 with noise over T steps. The reverse process learns to undo this corruption one step at a time. At each step t, a trained model predicts the noise ε_θ(x_t, t) and we apply a closed-form update to estimate x_{t-1}.
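To make the link between the forward corruption and the noise prediction concrete, here is a minimal sketch of the standard closed form of the DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * ε, together with its inversion. The function names q_sample and estimate_x0 are illustrative and not part of the problem, and plain Python lists stand in for tensors so the snippet runs without any dependencies.

```python
import math
import random

def q_sample(x0, alpha_bar_t, eps):
    """Closed-form forward sample: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * e for x, e in zip(x0, eps)]

def estimate_x0(x_t, alpha_bar_t, eps_pred):
    """Invert q_sample using a noise prediction in place of the true noise."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [(x - b * e) / a for x, e in zip(x_t, eps_pred)]

rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]
eps = [rng.gauss(0.0, 1.0) for _ in x0]

x_t = q_sample(x0, alpha_bar_t=0.3, eps=eps)
# With a perfect noise prediction, x0 is recovered up to floating-point error.
x0_hat = estimate_x0(x_t, alpha_bar_t=0.3, eps_pred=eps)
```

A trained model only approximates the true noise, so in practice the estimate of x_0 is imperfect; this is one reason sampling proceeds in many small steps rather than a single jump.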

Algorithm

Given the noise schedule betas and the cumulative products alpha_bar, where alpha_bar[t] is the product of (1 - betas[s]) for s = 0..t:

alpha_t  = 1 - betas[t]
mean     = (1 / sqrt(alpha_t)) * (x_t - (betas[t] / sqrt(1 - alpha_bar[t])) * predicted_noise)

if t > 0:
    z    = N(0, I)   # sampled with the given seed
    return mean + sqrt(betas[t]) * z
else:
    return mean      # no noise added at the final step
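The pseudocode above can be sketched as a dependency-free Python function. This is an illustration under stated assumptions, not the reference solution: nested lists stand in for (N, d) tensors, and random.Random(seed) stands in for a seeded framework generator, so its samples will not match the PyTorch-generated expected outputs.

```python
import math
import random

def ddpm_step(x_t, predicted_noise, t, betas, alpha_bar, seed):
    """One reverse step on an (N, d) nested list; returns x_{t-1} with the same shape."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    noise_coef = beta_t / math.sqrt(1.0 - alpha_bar[t])
    inv_sqrt_alpha = 1.0 / math.sqrt(alpha_t)

    mean = [[inv_sqrt_alpha * (x - noise_coef * e) for x, e in zip(x_row, e_row)]
            for x_row, e_row in zip(x_t, predicted_noise)]
    if t == 0:
        return mean  # final step: deterministic, the seed is unused

    rng = random.Random(seed)  # assumption: stand-in for a seeded tensor generator
    sigma = math.sqrt(beta_t)
    return [[m + sigma * rng.gauss(0.0, 1.0) for m in row] for row in mean]

# Tiny usage example with T = 2 (values chosen only for illustration).
betas = [0.1, 0.2]
alpha_bar = [0.9, 0.72]            # 0.9 and 0.9 * 0.8
x_t = [[1.0, -1.0], [0.5, 0.0]]
eps = [[0.1, 0.2], [-0.3, 0.4]]
out0 = ddpm_step(x_t, eps, 0, betas, alpha_bar, seed=0)  # deterministic mean
out1 = ddpm_step(x_t, eps, 1, betas, alpha_bar, seed=0)  # mean plus scaled noise
```

In PyTorch the same structure would typically use tensor arithmetic and draw z with torch.randn from a seeded torch.Generator; how the grader expects the seed to be applied is not specified here, so treat that detail as an assumption.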

Why It Works

This is the posterior mean of q(x_{t-1} | x_t, x_0) with x_0 estimated from the predicted noise (Ho et al. 2020, Equation 11). Adding sqrt(betas[t]) * z gives the sample a variance of betas[t], which re-introduces stochasticity during sampling. The exception is the last step (t = 0), where we return the deterministic mean.
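For readers who want the intermediate algebra, the substitution can be written out as follows. This sketch uses the paper's notation (\alpha_t = 1 - \beta_t, \bar\alpha_t = \prod_{s \le t} \alpha_s) rather than the 0-indexed arrays of this problem.

```latex
\begin{align*}
\tilde{\mu}_t(x_t, x_0)
  &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
   + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t, \\
\hat{x}_0
  &= \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}, \\
\tilde{\mu}_t(x_t, \hat{x}_0)
  &= \frac{\beta_t + \alpha_t\,(1-\bar\alpha_{t-1})}{(1-\bar\alpha_t)\sqrt{\alpha_t}}\,x_t
   - \frac{\beta_t}{\sqrt{\alpha_t}\,\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t) \\
  &= \frac{1}{\sqrt{\alpha_t}}
     \left( x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t) \right).
\end{align*}
```

The third line uses \sqrt{\bar\alpha_{t-1}} / \sqrt{\bar\alpha_t} = 1/\sqrt{\alpha_t}, and the last line uses \beta_t + \alpha_t (1 - \bar\alpha_{t-1}) = 1 - \bar\alpha_t, which follows from \beta_t = 1 - \alpha_t and \alpha_t \bar\alpha_{t-1} = \bar\alpha_t.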

Reference

Ho et al., “Denoising Diffusion Probabilistic Models” (2020), Algorithm 2.

PRNG note

PyTorch and JAX use different pseudo-random number generators. Given the same seed, they will produce different samples. The expected outputs for this problem are generated using PyTorch only. Your JAX solution will be tested for correctness of the algorithm (right distribution, right shape), not for exact value matching against PyTorch’s samples.
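A distribution-level check of this kind might look like the sketch below. It is only an illustration of the idea, not the grader's actual test: step_under_test is a hypothetical 1-D stand-in for a submitted solution, and the tolerances are arbitrary choices that sit far outside the expected sampling error for 20,000 draws.

```python
import math
import random
import statistics

def step_under_test(x_t, eps, beta_t, alpha_bar_t, seed):
    """Hypothetical stand-in for a solution's t > 0 branch, applied to a flat list."""
    rng = random.Random(seed)
    coef = beta_t / math.sqrt(1.0 - alpha_bar_t)
    scale = 1.0 / math.sqrt(1.0 - beta_t)
    sigma = math.sqrt(beta_t)
    return [scale * (x - coef * e) + sigma * rng.gauss(0.0, 1.0)
            for x, e in zip(x_t, eps)]

beta_t, alpha_bar_t = 0.02, 0.5
n = 20000
out = step_under_test([1.0] * n, [0.3] * n, beta_t, alpha_bar_t, seed=0)

# Every element shares the same analytic mean, so the residuals should
# look like draws from N(0, beta_t) regardless of which PRNG produced them.
expected_mean = (1.0 - beta_t / math.sqrt(1.0 - alpha_bar_t) * 0.3) / math.sqrt(1.0 - beta_t)
residuals = [o - expected_mean for o in out]

shape_ok = len(out) == n
mean_ok = abs(statistics.fmean(residuals)) < 0.01
var_ok = abs(statistics.pvariance(residuals) - beta_t) < 0.002
```

Comparing moments instead of raw values is what makes the check independent of the generator: two correct implementations with different PRNGs disagree element by element but agree on the mean and variance.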

Inputs / Output

  • x_t: tensor of shape (N, d) — current noisy sample at timestep t.
  • predicted_noise: tensor of shape (N, d) — model’s noise prediction.
  • t: int — current timestep (0-indexed).
  • betas: tensor of shape (T,) — noise schedule.
  • alpha_bar: tensor of shape (T,) — cumulative alphas.
  • seed: int — random seed for the Gaussian noise (used only when t > 0).

Output: x_{t-1} of shape (N, d).
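The problem supplies betas and alpha_bar as inputs, so the schedule itself is not part of the task. For local experimentation, one way to build a consistent pair is sketched below. The helper names are illustrative; the default endpoints follow the linear schedule reported by Ho et al. (1e-4 to 0.02 over T = 1000 steps) and are an assumption about what a test harness might use.

```python
from itertools import accumulate
from operator import mul

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas; the default endpoints are an assumption, not given by the problem."""
    if T == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (T - 1)
    return [beta_start + i * step for i in range(T)]

def cumulative_alphas(betas):
    """alpha_bar[t] = product of (1 - betas[s]) for s = 0..t."""
    return list(accumulate((1.0 - b for b in betas), mul))

betas = linear_beta_schedule(1000)
alpha_bar = cumulative_alphas(betas)
```

Because every factor 1 - betas[s] lies strictly between 0 and 1, alpha_bar decreases monotonically toward 0, which keeps 1 - alpha_bar[t] positive in the update's denominator.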

