DDPM Forward Noising
Implement the closed-form forward process of Denoising Diffusion Probabilistic Models (DDPM).
Background
Diffusion models corrupt a clean input x_0 over T timesteps by adding Gaussian noise.
Naively, reaching timestep t means applying the noising step t times in sequence. The
reparameterization in DDPM (Ho et al. 2020) shows you can jump directly to any timestep t
using alpha_bar[t],
the cumulative product of signal-preservation factors:
q(x_t | x_0) = N(x_t; sqrt(alpha_bar[t]) * x_0, (1 - alpha_bar[t]) * I)
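To see that the closed form agrees with step-by-step noising, the sketch below compares the two routes empirically. It uses NumPy as a stand-in for PyTorch or JAX. The beta schedule, the scalar input, and the 0-indexed convention (alpha_bar[t] covers betas[0..t]) are assumptions made for illustration, not part of the problem statement. The one-step kernel is the one given in Ho et al. (2020).

```python
import numpy as np

rng = np.random.default_rng(0)

T = 50
betas = np.linspace(1e-4, 0.2, T)     # hypothetical schedule, illustration only
alpha_bar = np.cumprod(1.0 - betas)   # assumed convention: alpha_bar[t] covers betas[0..t]

x0 = np.full(100_000, 2.0)            # many copies of the same scalar input
t = 30

# Iterative route: apply the one-step kernel
# q(x_s | x_{s-1}) = N(sqrt(1 - beta_s) * x_{s-1}, beta_s * I) for s = 0..t.
x_iter = x0.copy()
for s in range(t + 1):
    eps = rng.standard_normal(x_iter.shape)
    x_iter = np.sqrt(1.0 - betas[s]) * x_iter + np.sqrt(betas[s]) * eps

# Closed-form route: a single draw from q(x_t | x_0).
eps = rng.standard_normal(x0.shape)
x_closed = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

print(x_iter.mean(), x_closed.mean())  # both near sqrt(alpha_bar[t]) * 2.0
print(x_iter.var(), x_closed.var())    # both near 1 - alpha_bar[t]
```

The individual samples differ because the two routes consume different random draws; only the distribution of x_t is shared.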
Algorithm
a = sqrt(alpha_bar[t]) # signal scaling
b = sqrt(1 - alpha_bar[t]) # noise scaling
noise = sample from N(0, I) with the same shape as x_0, using the given seed
x_t = a * x_0 + b * noise
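The steps above can be sketched in a few lines. This is a minimal NumPy version, not the reference solution: the function name forward_noise and the example alpha_bar values are hypothetical, and a PyTorch or JAX solution would use that library's own seeded generator in place of np.random.default_rng.

```python
import numpy as np

def forward_noise(x0, t, alpha_bar, seed):
    """Sample x_t ~ q(x_t | x_0) in one step (NumPy stand-in sketch)."""
    x0 = np.asarray(x0, dtype=np.float64)
    a = np.sqrt(alpha_bar[t])              # signal scaling
    b = np.sqrt(1.0 - alpha_bar[t])        # noise scaling
    rng = np.random.default_rng(seed)      # seeded generator for reproducibility
    noise = rng.standard_normal(x0.shape)  # N(0, I), same shape as x0
    return a * x0 + b * noise

alpha_bar = np.array([0.99, 0.5, 0.01])    # made-up values for illustration
x0 = np.ones((4, 3))
x_t = forward_noise(x0, t=1, alpha_bar=alpha_bar, seed=0)
print(x_t.shape)  # (4, 3)
```

Creating the generator inside the function means the same seed always yields the same x_t, which is what the seed input is for.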
Why It Works
alpha_bar[t] = (1 - beta_1) * ... * (1 - beta_t) encodes how much clean signal survives
after t steps. When alpha_bar[t] ≈ 1 (early steps) the output stays close to x_0;
when alpha_bar[t] ≈ 0 (late steps) the output is almost pure noise.
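To make the early-versus-late behavior concrete, the sketch below builds alpha_bar as a cumulative product and prints the signal and noise scales at a few timesteps. The linear beta schedule from 1e-4 to 0.02 over T = 1000 steps is an assumption borrowed from Ho et al. (2020); the schedule used by this problem may differ.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)    # assumed linear schedule, illustration only
alpha_bar = np.cumprod(1.0 - betas)   # running product of (1 - beta)

for t in (0, T // 2, T - 1):
    a = np.sqrt(alpha_bar[t])         # weight on the clean signal
    b = np.sqrt(1.0 - alpha_bar[t])   # weight on the noise
    print(f"t={t:4d}  alpha_bar={alpha_bar[t]:.5f}  signal={a:.3f}  noise={b:.3f}")
```

Because every factor 1 - beta lies strictly between 0 and 1, alpha_bar decreases monotonically: the signal weight shrinks toward 0 and the noise weight grows toward 1.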
Reference
Ho et al., “Denoising Diffusion Probabilistic Models” (2020).
PRNG note
PyTorch and JAX use different pseudo-random number generators. Given the same
seed, they will produce different samples. The expected outputs for this
problem are generated using PyTorch only. Your JAX solution will be tested
for correctness of the algorithm (right distribution, right shape), not for
exact value matching against PyTorch’s samples.
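The same effect can be reproduced inside a single library. The sketch below is only an analogy: it seeds two different NumPy bit generators (PCG64 and MT19937) with the same value, standing in for the PyTorch and JAX generators. The sample values differ, while the statistics a distribution-level check would look at agree.

```python
import numpy as np

seed = 0
# Two different bit generators, seeded identically (both ship with NumPy).
rng_a = np.random.Generator(np.random.PCG64(seed))
rng_b = np.random.Generator(np.random.MT19937(seed))

noise_a = rng_a.standard_normal((10_000, 4))
noise_b = rng_b.standard_normal((10_000, 4))

print(np.allclose(noise_a, noise_b))   # False: same seed, different streams
print(noise_a.mean(), noise_b.mean())  # both close to 0
print(noise_a.std(), noise_b.std())    # both close to 1
```

A check in this spirit compares shape, mean, and spread rather than individual entries.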
Inputs / Output
- x0: tensor of shape (N, d), the clean input.
- t: int, the timestep index.
- alpha_bar: tensor of shape (T,), the precomputed cumulative alphas (from ddpm-noise-schedule).
- seed: int, the random seed for reproducibility.
Output: x_t of shape (N, d).