DDPM Training Step End-to-End
Implement one complete DDPM training step (forward noising, noise prediction, and a hand-computed gradient update) for a tiny linear noise-prediction model.
No torch.autograd. No loss.backward(). Every partial derivative is derived by hand.
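Without autograd, a central finite-difference check is a cheap way to confirm a hand-derived gradient. The sketch below assumes a generic linear predictor P = X W with mean-squared-error loss, the same form as this problem's model. The helper names (mse_loss, analytic_grad, finite_diff_grad) are hypothetical, and float64 is used so rounding error stays well below the comparison tolerance.

```python
import torch

def mse_loss(inp, w, target):
    """Mean squared error of the linear prediction inp @ w against target."""
    return ((inp @ w - target) ** 2).mean().item()

def analytic_grad(inp, w, target):
    """Hand-derived dL/dW for L = mean((inp @ w - target)**2)."""
    n, d = target.shape
    d_pred = (2.0 / (n * d)) * (inp @ w - target)  # dL/dP, shape (N, d)
    return inp.T @ d_pred                           # dL/dW, shape (d+1, d)

def finite_diff_grad(inp, w, target, h=1e-6):
    """Central-difference estimate of dL/dW, perturbing one weight at a time."""
    grad = torch.zeros_like(w)
    for i in range(w.shape[0]):
        for j in range(w.shape[1]):
            w_plus = w.clone()
            w_plus[i, j] += h
            w_minus = w.clone()
            w_minus[i, j] -= h
            grad[i, j] = (mse_loss(inp, w_plus, target) - mse_loss(inp, w_minus, target)) / (2 * h)
    return grad

# Small random problem in float64 so rounding error stays far below the tolerance.
gen = torch.Generator().manual_seed(0)
N, d = 4, 3
inp = torch.randn(N, d + 1, generator=gen, dtype=torch.float64)
w = torch.randn(d + 1, d, generator=gen, dtype=torch.float64)
target = torch.randn(N, d, generator=gen, dtype=torch.float64)

max_err = (analytic_grad(inp, w, target) - finite_diff_grad(inp, w, target)).abs().max().item()
print(f"max |analytic - finite difference| = {max_err:.2e}")
```

Because the loss is quadratic in W, the central difference is exact apart from rounding, so max_err should be tiny; a large value points to a mistake in the hand derivation.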
Model
A minimal linear noise predictor (no activations, for gradient-checking clarity):
input = concat([x_t, t_norm * ones(N, 1)], dim=-1)   # shape (N, d+1), with t_norm = t / T
predicted_noise = input @ w_model # shape (N, d)
where w_model has shape (d+1, d).
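The hand derivation behind the gradient step is two applications of the chain rule. Writing X for the (N, d+1) input matrix, W for w_model, P = XW for predicted_noise and ε for the sampled noise (symbols chosen here for brevity):

```latex
\begin{aligned}
\mathcal{L} &= \frac{1}{Nd}\sum_{n=1}^{N}\sum_{k=1}^{d}\left(P_{nk}-\varepsilon_{nk}\right)^2,
\qquad P_{nk} = \sum_{i=1}^{d+1} X_{ni} W_{ik} \\
\frac{\partial \mathcal{L}}{\partial P_{nk}} &= \frac{2}{Nd}\left(P_{nk}-\varepsilon_{nk}\right) \\
\frac{\partial \mathcal{L}}{\partial W_{ik}} &= \sum_{n=1}^{N} X_{ni}\,\frac{\partial \mathcal{L}}{\partial P_{nk}}
\quad\Longrightarrow\quad
\nabla_W \mathcal{L} = X^\top \frac{\partial \mathcal{L}}{\partial P}
\end{aligned}
```

Neither x_t nor ε depends on W, so no gradient flows back through the noising step.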
Pipeline
- Sample timestep t uniformly from [0, T):
    gen_t = torch.Generator().manual_seed(int(seed))
    t = int(torch.randint(0, T, (1,), generator=gen_t).item())
- Sample noise eps ~ N(0, I) with shape (N, d):
    gen_eps = torch.Generator().manual_seed(int(seed) + 1)
    eps = torch.randn(N, d, generator=gen_eps)
- Forward noising (DDPM q(x_t | x_0)):
    x_t = sqrt(alpha_bar[t]) * x_clean + sqrt(1 - alpha_bar[t]) * eps
- Build model input:
    t_norm = t / T
    input = concat([x_t, t_norm * ones(N, 1)], dim=-1)   # (N, d+1)
- Predict:
    predicted_noise = input @ w_model
- MSE loss:
    loss = mean((predicted_noise - eps)**2)
- Gradient by hand:
    d_predicted = (2 / (N * d)) * (predicted_noise - eps)   # (N, d)
    grad_w_model = input.T @ d_predicted                    # (d+1, d)
- SGD update:
    w_model -= lr * grad_w_model
- Return w_model.flatten().
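The pipeline above can be assembled into a single function. This is a minimal sketch, not the reference solution: the name ddpm_train_step and the usage values (N, d, T, and the linear betas) are illustrative assumptions. It returns a new flattened tensor instead of updating w_model in place, which gives the same values while leaving the caller's weights untouched.

```python
import torch

def ddpm_train_step(x_clean, w_model, betas, alpha_bar, lr, seed):
    """One DDPM training step for the linear noise predictor (hypothetical name)."""
    N, d = x_clean.shape
    T = betas.shape[0]  # betas only supplies T here; the noising reads alpha_bar

    # 1. Timestep t ~ Uniform{0, ..., T-1} from its own generator.
    gen_t = torch.Generator().manual_seed(int(seed))
    t = int(torch.randint(0, T, (1,), generator=gen_t).item())

    # 2. Target noise eps ~ N(0, I) from a second generator seeded with seed + 1.
    gen_eps = torch.Generator().manual_seed(int(seed) + 1)
    eps = torch.randn(N, d, generator=gen_eps)

    # 3. Forward noising q(x_t | x_0).
    x_t = torch.sqrt(alpha_bar[t]) * x_clean + torch.sqrt(1 - alpha_bar[t]) * eps

    # 4. Append the normalized timestep as an extra input column.
    t_norm = t / T
    inp = torch.cat([x_t, torch.full((N, 1), t_norm, dtype=x_t.dtype)], dim=-1)  # (N, d+1)

    # 5. Predict the noise.
    predicted_noise = inp @ w_model  # (N, d)

    # 6. MSE loss (useful for logging; the update below only needs its gradient).
    loss = ((predicted_noise - eps) ** 2).mean()

    # 7. Hand-derived gradient, no autograd.
    d_predicted = (2.0 / (N * d)) * (predicted_noise - eps)  # dL/dP, (N, d)
    grad_w_model = inp.T @ d_predicted                        # dL/dW, (d+1, d)

    # 8. SGD update on a copy, then flatten.
    w_new = w_model - lr * grad_w_model
    return w_new.flatten()

# Illustrative inputs (assumed values, float32 throughout).
gen = torch.Generator().manual_seed(0)
N, d, T = 8, 2, 100
x_clean = torch.randn(N, d, generator=gen)
w_model = torch.zeros(d + 1, d)
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

out = ddpm_train_step(x_clean, w_model, betas, alpha_bar, lr=0.1, seed=42)
print(out.shape)
```

Keeping the two generators local to the function means the result depends only on the arguments, which is what makes the expected outputs checkable.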
Inputs
- x_clean: shape (N, d), clean training samples.
- w_model: shape (d+1, d), noise-prediction model weights.
- betas: shape (T,), diffusion noise schedule.
- alpha_bar: shape (T,), cumulative product of (1 - betas).
- lr: float, SGD learning rate.
- seed: int, seeds the timestep generator directly and the noise generator as seed + 1.
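The problem does not say how betas is produced. As an assumption for local testing, the linear schedule used by Ho et al. (beta rising from 1e-4 to 0.02 over T = 1000 steps) yields inputs of the right shapes. linear_schedule is a hypothetical helper; any schedule whose alpha_bar is the cumulative product of (1 - betas) fits the interface.

```python
import torch

def linear_schedule(T: int, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Hypothetical helper: linearly spaced betas and alpha_bar = cumprod(1 - betas)."""
    betas = torch.linspace(beta_start, beta_end, T)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    return betas, alpha_bar

betas, alpha_bar = linear_schedule(1000)
print(alpha_bar[0].item(), alpha_bar[-1].item())
```

alpha_bar starts just below 1 (almost no noise at t = 0) and decays toward 0, so late timesteps are nearly pure noise.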
Output
Returns a tensor of shape ((d+1)*d,): w_model.flatten() after the SGD update.
PRNG note
PyTorch and JAX use different pseudo-random number generators. Expected outputs
here are generated with PyTorch only, so reusing the same seed in JAX will not reproduce them.
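Both draws come from dedicated torch.Generator objects rather than the global RNG, so (t, eps) depend only on seed and can be replayed exactly within PyTorch. A small sketch, using a hypothetical helper sample_t_and_eps that copies the two sampling steps from the pipeline:

```python
import torch

def sample_t_and_eps(seed: int, T: int, N: int, d: int):
    """Hypothetical helper: draw (t, eps) exactly as the pipeline specifies."""
    gen_t = torch.Generator().manual_seed(int(seed))
    t = int(torch.randint(0, T, (1,), generator=gen_t).item())
    gen_eps = torch.Generator().manual_seed(int(seed) + 1)
    eps = torch.randn(N, d, generator=gen_eps)
    return t, eps

t1, eps1 = sample_t_and_eps(seed=7, T=1000, N=4, d=3)
t2, eps2 = sample_t_and_eps(seed=7, T=1000, N=4, d=3)

# Same seed, same draws: the step is reproducible within PyTorch.
same = t1 == t2 and torch.equal(eps1, eps2)
print(t1, same)
```

Because the generators are local, calling the helper does not advance the global stream set by torch.manual_seed, so other random code in the same process cannot shift t or eps.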
Reference: Ho et al., "Denoising Diffusion Probabilistic Models" (2020).