DDPM Denoising Step
Implement one step of the DDPM reverse (denoising) process — given a noisy sample
x_t and the model’s noise prediction, produce the denoised estimate x_{t-1}.
Background
The forward DDPM process gradually corrupts x_0 with noise over T steps.
The reverse process learns to undo this corruption one step at a time. At each
step t, a trained model predicts the noise ε_θ(x_t, t) and we apply a closed-form
update to estimate x_{t-1}.
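The forward process has a closed form that jumps straight from x_0 to x_t: x_t = sqrt(alpha_bar[t]) * x_0 + sqrt(1 - alpha_bar[t]) * eps with eps ~ N(0, I) (Ho et al. 2020). The problem statement does not spell this out, so treat the NumPy sketch below as an illustration under assumptions. It uses the linear schedule from that paper (betas from 1e-4 to 0.02 over T = 1000 steps), and make_schedule and q_sample are hypothetical helper names, not part of the problem's interface.

```python
import numpy as np

def make_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: linear beta schedule and its cumulative product."""
    betas = np.linspace(beta_start, beta_end, T)
    alpha_bar = np.cumprod(1.0 - betas)  # alpha_bar[t] = prod_{s<=t} (1 - betas[s])
    return betas, alpha_bar

def q_sample(x_0, t, alpha_bar, rng):
    """Hypothetical helper: sample x_t ~ q(x_t | x_0) in one shot."""
    eps = rng.standard_normal(x_0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x_0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return x_t, eps

rng = np.random.default_rng(0)
betas, alpha_bar = make_schedule()
x_0 = rng.standard_normal((4, 2))
x_t, eps = q_sample(x_0, 500, alpha_bar, rng)
print(x_t.shape)      # (4, 2)
print(alpha_bar[-1])  # a small number near 0: the last step is almost pure noise
```

Because alpha_bar shrinks monotonically toward 0, late timesteps carry almost no signal from x_0, which is why sampling can start from pure Gaussian noise.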
Algorithm
Given the noise schedule betas and the cumulative alphas alpha_bar, where alpha_bar[t] = (1 - betas[0]) * (1 - betas[1]) * ... * (1 - betas[t]):
alpha_t = 1 - betas[t]
mean = (1 / sqrt(alpha_t)) * (x_t - (betas[t] / sqrt(1 - alpha_bar[t])) * predicted_noise)
if t > 0:
    z = N(0, I)  # sampled with the given seed
    return mean + sqrt(betas[t]) * z
else:
    return mean  # no noise added at the final step
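The pseudocode above translates almost line for line into the sketch below. It is written in NumPy rather than PyTorch or JAX so it runs with no framework installed; the inputs are assumed to be NumPy arrays, and the name ddpm_step is hypothetical. Because NumPy's generator differs from PyTorch's, the noisy outputs (t > 0) will not match the reference values, for the reason given in the PRNG note.

```python
import numpy as np

def ddpm_step(x_t, predicted_noise, t, betas, alpha_bar, seed):
    """Hypothetical NumPy sketch of one reverse step x_t -> x_{t-1} (variance = betas[t])."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # Posterior mean, with x_0 implied by the predicted noise.
    mean = (x_t - (beta_t / np.sqrt(1.0 - alpha_bar[t])) * predicted_noise) / np.sqrt(alpha_t)
    if t > 0:
        # Fresh N(0, I) noise from NumPy's generator; values will NOT match PyTorch's.
        z = np.random.default_rng(seed).standard_normal(x_t.shape)
        return mean + np.sqrt(beta_t) * z
    return mean  # t == 0: deterministic mean, no noise added

# Usage with an illustrative linear schedule.
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)
x_t = np.ones((2, 3))
eps_hat = np.zeros((2, 3))
x_prev = ddpm_step(x_t, eps_hat, 10, betas, alpha_bar, seed=0)
print(x_prev.shape)  # (2, 3)
```

Seeding a local generator with np.random.default_rng(seed) on each call keeps the step a pure function of its arguments: the same seed always yields the same x_{t-1}, and no global random state is touched.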
Why It Works
This is the posterior mean of q(x_{t-1} | x_t, x_0) with x_0 estimated from the
predicted noise (Ho et al. 2020, Equation 11). The noise term sqrt(betas[t]) * z,
whose variance is betas[t], re-introduces stochasticity during sampling, except at
the last step (t = 0), where we return the deterministic mean.
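The claim that the update equals a posterior mean can be checked numerically. The sketch below assumes two formulas from Ho et al. (2020) that this problem statement does not write out: the x_0 estimate x0_hat = (x_t - sqrt(1 - alpha_bar[t]) * eps) / sqrt(alpha_bar[t]), and the forward-process posterior mean, a weighted sum of x0_hat and x_t. It compares that route with the closed-form mean for one t >= 1, where alpha_bar[t-1] is defined; the schedule and random inputs are arbitrary.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

rng = np.random.default_rng(0)
x_t = rng.standard_normal((3, 2))
eps = rng.standard_normal((3, 2))  # stands in for predicted_noise
t = 250

# Route 1: the closed-form update used in the algorithm (Equation 11).
mean_eq11 = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])

# Route 2: estimate x_0 from the noise, then take the posterior mean of q(x_{t-1} | x_t, x_0).
x0_hat = (x_t - np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alpha_bar[t])
coef_x0 = np.sqrt(alpha_bar[t - 1]) * betas[t] / (1.0 - alpha_bar[t])
coef_xt = np.sqrt(alphas[t]) * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t])
mean_posterior = coef_x0 * x0_hat + coef_xt * x_t

print(np.allclose(mean_eq11, mean_posterior))  # True
```

Algebraically, the x_t coefficients combine to (betas[t] + alphas[t] * (1 - alpha_bar[t-1])) / ((1 - alpha_bar[t]) * sqrt(alphas[t])) = 1 / sqrt(alphas[t]), since betas[t] + alphas[t] - alpha_bar[t] = 1 - alpha_bar[t]; the eps coefficient reduces to the one in Equation 11 the same way.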
Reference
Ho et al., “Denoising Diffusion Probabilistic Models” (2020), Algorithm 2.
PRNG note
PyTorch and JAX use different pseudo-random number generators. Given the same
seed, they will produce different samples. The expected outputs for this
problem are generated using PyTorch only. Your JAX solution will be tested
for correctness of the algorithm (right distribution, right shape), not for
exact value matching against PyTorch’s samples.
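Since only the algorithm (distribution and shape) is compared across frameworks, a check has to look at statistics rather than exact values. The sketch below shows one plausible check, not the grader actually used here, which is not described; the sample size, seeds, and 5% tolerances are arbitrary assumptions, and ddpm_step is a hypothetical compact NumPy version of the update in the Algorithm section. It subtracts the deterministic mean and verifies that the remainder has mean about 0 and standard deviation about sqrt(betas[t]).

```python
import numpy as np

def ddpm_step(x_t, predicted_noise, t, betas, alpha_bar, seed):
    # Hypothetical compact version of the reverse-step update.
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * predicted_noise) / np.sqrt(1.0 - betas[t])
    if t > 0:
        z = np.random.default_rng(seed).standard_normal(x_t.shape)
        return mean + np.sqrt(betas[t]) * z
    return mean

betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)
rng = np.random.default_rng(123)
x_t = rng.standard_normal((20000, 4))
eps = rng.standard_normal((20000, 4))
t = 500

out = ddpm_step(x_t, eps, t, betas, alpha_bar, seed=7)
mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(1.0 - betas[t])
residual = out - mean  # should be sqrt(betas[t]) * N(0, I)

print(out.shape == x_t.shape)                                # True
print(abs(residual.mean()) < 0.05 * np.sqrt(betas[t]))       # True
print(abs(residual.std() / np.sqrt(betas[t]) - 1.0) < 0.05)  # True
```

Within a single library, the same seed still reproduces the same output, so exact-value tests remain possible as long as both sides use the same generator.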
Inputs / Output
- x_t: tensor of shape (N, d), the current noisy sample at timestep t.
- predicted_noise: tensor of shape (N, d), the model's noise prediction.
- t: int, the current timestep (0-indexed).
- betas: tensor of shape (T,), the noise schedule.
- alpha_bar: tensor of shape (T,), the cumulative alphas.
- seed: int, the random seed for the Gaussian noise (used only when t > 0).
Output: x_{t-1} of shape (N, d).
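To see how these inputs fit together during sampling, the sketch below applies the step repeatedly, from pure Gaussian noise down to t = 0, keeping the (N, d) shape throughout. The noise predictor is a hypothetical placeholder that always predicts zero, and the 50-step linear schedule and the choice seed=t at each step are illustrative assumptions, not part of the problem.

```python
import numpy as np

def ddpm_step(x_t, predicted_noise, t, betas, alpha_bar, seed):
    # Hypothetical compact version of the reverse-step update.
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * predicted_noise) / np.sqrt(1.0 - betas[t])
    if t > 0:
        z = np.random.default_rng(seed).standard_normal(x_t.shape)
        return mean + np.sqrt(betas[t]) * z
    return mean

def dummy_noise_model(x_t, t):
    # Hypothetical stand-in for a trained eps_theta(x_t, t): always predicts zero noise.
    return np.zeros_like(x_t)

T, N, d = 50, 8, 2
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

x = np.random.default_rng(0).standard_normal((N, d))  # start from pure Gaussian noise
for t in reversed(range(T)):  # t = T-1, ..., 0
    eps = dummy_noise_model(x, t)
    x = ddpm_step(x, eps, t, betas, alpha_bar, seed=t)  # one seed per step (an assumption)
print(x.shape)  # (8, 2)
```

With a real model, only dummy_noise_model would change; the loop and the per-step update stay the same.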