medium end_to_end

DDPM Training Step End-to-End

Implement one complete DDPM training step (forward noising, noise prediction, and a hand-computed gradient update) for a tiny linear noise-prediction model.

No torch.autograd. No loss.backward(). Every partial derivative is derived by hand.

Model

A minimal linear noise predictor (no activations, for gradient-checking clarity):

input = concat([x_t, t_normalized])    # shape (N, d+1)
predicted_noise = input @ w_model      # shape (N, d)

where w_model has shape (d+1, d).
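
As a quick shape check, the linear predictor can be sketched as below. The sizes N = 4 and d = 3 and the zero-filled tensors are placeholders chosen for illustration, not values from the problem; the variable is named model_input here only to avoid shadowing Python's built-in input.

```python
import torch

N, d = 4, 3                              # hypothetical batch size and data dimension
x_t = torch.zeros(N, d)                  # stand-in for the noised samples
t_normalized = torch.full((N, 1), 0.5)   # the same scalar time feature on every row
w_model = torch.zeros(d + 1, d)          # one weight row per input feature

# Append the time feature as an extra column, then apply the linear map.
model_input = torch.cat([x_t, t_normalized], dim=-1)   # (N, d+1)
predicted_noise = model_input @ w_model                # (N, d)
```

The extra row of w_model multiplies the time column, so it behaves like a bias scaled by t / T.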

Pipeline

  1. Sample timestep t uniformly from [0, T):
    gen_t = torch.Generator().manual_seed(int(seed))
    t = int(torch.randint(0, T, (1,), generator=gen_t).item())
  2. Sample noise eps ~ N(0, I) with shape (N, d):
    gen_eps = torch.Generator().manual_seed(int(seed) + 1)
    eps = torch.randn(N, d, generator=gen_eps)
  3. Forward noising (DDPM q(x_t | x_0)):
    x_t = sqrt(alpha_bar[t]) * x_clean + sqrt(1 - alpha_bar[t]) * eps
  4. Build model input:
    t_norm = t / T
    input = concat([x_t, t_norm * ones(N, 1)], dim=-1)   # (N, d+1)
  5. Predict: predicted_noise = input @ w_model
  6. MSE loss: mean((predicted_noise - eps)**2)
  7. Gradient by hand:
    d_predicted = (2 / (N * d)) * (predicted_noise - eps)   # (N, d)
    grad_w_model = input.T @ d_predicted                     # (d+1, d)
  8. SGD update: w_model -= lr * grad_w_model
  9. Return w_model.flatten().
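
The nine steps above can be put together roughly as follows. This is a sketch under stated assumptions rather than the reference solution: the function name ddpm_train_step is made up for illustration, T is read from alpha_bar, betas is accepted but unused because alpha_bar already encodes the schedule, and the update is done out of place so the caller's w_model is not mutated.

```python
import torch

def ddpm_train_step(x_clean, w_model, betas, alpha_bar, lr, seed):
    """One DDPM training step with a hand-derived gradient (no autograd)."""
    N, d = x_clean.shape
    T = alpha_bar.shape[0]  # assumption: T is inferred from the schedule length

    # Steps 1-2: seeded timestep and noise, exactly as in the pipeline.
    gen_t = torch.Generator().manual_seed(int(seed))
    t = int(torch.randint(0, T, (1,), generator=gen_t).item())
    gen_eps = torch.Generator().manual_seed(int(seed) + 1)
    eps = torch.randn(N, d, generator=gen_eps).to(x_clean.dtype)

    # Step 3: forward noising q(x_t | x_0).
    x_t = torch.sqrt(alpha_bar[t]) * x_clean + torch.sqrt(1 - alpha_bar[t]) * eps

    # Step 4: append the normalized timestep as an extra column.
    t_col = torch.full((N, 1), t / T, dtype=x_clean.dtype)
    model_input = torch.cat([x_t, t_col], dim=-1)           # (N, d+1)

    # Steps 5-7: prediction and the hand-computed MSE gradient.
    predicted_noise = model_input @ w_model                  # (N, d)
    d_predicted = (2.0 / (N * d)) * (predicted_noise - eps)  # dL/d(predicted_noise)
    grad_w_model = model_input.T @ d_predicted               # (d+1, d)

    # Steps 8-9: SGD update (out of place) and flatten.
    return (w_model - lr * grad_w_model).flatten()
```

The loss value from step 6 is never needed for the update itself; only its derivative with respect to predicted_noise enters the gradient.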

Inputs

  • x_clean: shape (N, d). Clean training samples.
  • w_model: shape (d+1, d). Noise-prediction model weights.
  • betas: shape (T,). Diffusion noise schedule.
  • alpha_bar: shape (T,). Cumulative product of (1 - betas).
  • lr: float. SGD learning rate.
  • seed: int. Seeds the timestep generator (seed) and the noise generator (seed + 1).
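
For local experiments, the two schedule inputs can be built as in the sketch below. The linear schedule from 1e-4 to 0.02 with T = 1000 is the setting reported by Ho et al.; the problem statement does not say which schedule its test cases use, so treat these values as an assumption.

```python
import torch

T = 1000                                        # assumed number of diffusion steps
betas = torch.linspace(1e-4, 0.02, T)           # linear schedule (an assumption)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)   # alpha_bar[t] = product of (1 - betas[s]) for s <= t
```

alpha_bar starts just below 1 and decays toward 0, so a small t leaves x_t close to x_clean while a large t makes it nearly pure noise.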

Output

Returns a tensor of shape ((d+1)*d,): w_model.flatten() after the SGD update.
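
Since the returned weights depend entirely on the hand-derived gradient, it can be worth checking that gradient numerically before trusting it. The sketch below compares the step 7 formula against central finite differences, still without autograd. The helper names, the random test tensors, and the step size h are all illustrative choices; float64 is used only to keep rounding error small.

```python
import torch

def mse_loss(w, model_input, eps):
    # Step 6: mean over all N * d entries.
    return ((model_input @ w - eps) ** 2).mean()

def hand_gradient(w, model_input, eps):
    # Step 7: chain rule through the mean and the linear map.
    N, d = eps.shape
    d_predicted = (2.0 / (N * d)) * (model_input @ w - eps)
    return model_input.T @ d_predicted

def finite_difference_gradient(w, model_input, eps, h=1e-6):
    # Central difference on one weight entry at a time.
    grad = torch.zeros_like(w)
    for i in range(w.shape[0]):
        for j in range(w.shape[1]):
            bump = torch.zeros_like(w)
            bump[i, j] = h
            up = mse_loss(w + bump, model_input, eps)
            down = mse_loss(w - bump, model_input, eps)
            grad[i, j] = (up - down) / (2 * h)
    return grad

gen = torch.Generator().manual_seed(0)
model_input = torch.randn(5, 4, generator=gen, dtype=torch.float64)  # (N, d+1)
eps = torch.randn(5, 3, generator=gen, dtype=torch.float64)          # (N, d)
w = torch.randn(4, 3, generator=gen, dtype=torch.float64)            # (d+1, d)

max_abs_diff = (hand_gradient(w, model_input, eps)
                - finite_difference_gradient(w, model_input, eps)).abs().max()
```

Because the loss is quadratic in w, the central difference has no truncation error, so max_abs_diff should sit at the level of floating-point rounding.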

PRNG note

PyTorch and JAX use different pseudo-random number generators, so the same seed produces different samples in each framework. Expected outputs here are generated using PyTorch only (framework: pytorch). Reference: Ho et al., "Denoising Diffusion Probabilistic Models" (2020).
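
Within PyTorch, the two seeded generators make the sampling repeatable. A small sketch of that property follows; the helper name sample_t_and_eps and the argument values are illustrative only.

```python
import torch

def sample_t_and_eps(seed, T, N, d):
    # Separate generators: seed for the timestep, seed + 1 for the noise.
    gen_t = torch.Generator().manual_seed(int(seed))
    t = int(torch.randint(0, T, (1,), generator=gen_t).item())
    gen_eps = torch.Generator().manual_seed(int(seed) + 1)
    eps = torch.randn(N, d, generator=gen_eps)
    return t, eps

# Two calls with the same seed give the same timestep and the same noise.
t_a, eps_a = sample_t_and_eps(seed=42, T=100, N=2, d=3)
t_b, eps_b = sample_t_and_eps(seed=42, T=100, N=2, d=3)
same_draw = (t_a == t_b) and torch.equal(eps_a, eps_b)
```

Using dedicated torch.Generator objects rather than the global RNG keeps the draw independent of any other random calls made earlier in the program.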

