Train VAE End-to-End
Perform one full VAE training step on a fully-connected encoder/decoder: forward pass, ELBO loss, and hand-computed gradients via the chain rule.
No torch.autograd. No loss.backward(). You derive every partial derivative yourself.
Architecture
A minimal linear VAE (no activation functions, for gradient-checking clarity):
- Encoder: h = x @ w_enc gives (mu, log_var) concatenated along the last dimension. The first d_z columns of h are mu; the last d_z columns are log_var.
- Reparameterization: sigma = exp(0.5 * log_var), sample eps ~ N(0, I) with the given seed, then z = mu + sigma * eps.
- Decoder: x_hat = z @ w_dec.
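The three steps above can be sketched in NumPy as follows. This is an illustrative sketch, not the reference solution: the function name vae_forward, the toy shapes, and the use of a NumPy generator for eps are assumptions made here, and NumPy noise will not match the PyTorch noise behind the expected outputs.

```python
import numpy as np

def vae_forward(x, w_enc, w_dec, eps):
    """Forward pass of the linear VAE; eps is pre-sampled N(0, I) noise."""
    n_latent = w_dec.shape[0]                 # latent size d_z
    h = x @ w_enc                             # (N, 2*d_z)
    mu, log_var = h[:, :n_latent], h[:, n_latent:]
    sigma = np.exp(0.5 * log_var)             # standard deviation
    z = mu + sigma * eps                      # reparameterization
    x_hat = z @ w_dec                         # (N, d)
    return mu, log_var, sigma, z, x_hat

# Hypothetical toy sizes: N=4, d=3, d_z=2
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
w_enc = 0.1 * rng.standard_normal((3, 4))
w_dec = 0.1 * rng.standard_normal((2, 3))
eps = rng.standard_normal((4, 2))

mu, log_var, sigma, z, x_hat = vae_forward(x, w_enc, w_dec, eps)
print(mu.shape, log_var.shape, z.shape, x_hat.shape)  # (4, 2) (4, 2) (4, 2) (4, 3)
```

Passing eps in as an argument, rather than sampling inside the function, keeps the forward pass deterministic, which makes gradient checks easier.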
Loss (negative ELBO)
recon_loss = ((x_hat - x)**2).sum(dim=-1).mean() # mean over batch, sum over d
kl_loss = -0.5 * (1 + log_var - mu**2 - exp(log_var)).sum(dim=-1).mean()
loss = recon_loss + kl_loss
The KL term has a closed form for q(z|x) = N(mu, diag(exp(log_var))) vs p(z) = N(0, I).
Reference: Kingma & Welling, "Auto-Encoding Variational Bayes" (2013).
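As a quick sanity check on the loss, the sketch below evaluates both terms on hand-picked values. The helper name neg_elbo and the toy inputs are assumptions for illustration only. With mu = 0 and log_var = 0 the posterior equals the prior, so the KL term should vanish; with mu = 1 and log_var = 0 each latent dimension contributes 0.5.

```python
import numpy as np

def neg_elbo(x, x_hat, mu, log_var):
    """Negative ELBO: squared-error reconstruction plus closed-form KL."""
    recon = ((x_hat - x) ** 2).sum(axis=-1).mean()   # sum over d, mean over batch
    kl = (-0.5 * (1 + log_var - mu**2 - np.exp(log_var))).sum(axis=-1).mean()
    return recon + kl, recon, kl

x = np.zeros((2, 3))
x_hat = np.ones((2, 3))          # every element is off by exactly 1

# Case 1: q(z|x) = N(0, I), identical to the prior
_, recon_a, kl_a = neg_elbo(x, x_hat, np.zeros((2, 2)), np.zeros((2, 2)))

# Case 2: mean shifted to 1 in both latent dimensions, unit variance
_, recon_b, kl_b = neg_elbo(x, x_hat, np.ones((2, 2)), np.zeros((2, 2)))

print(np.isclose(recon_a, 3.0), np.isclose(kl_a, 0.0))  # True True
print(np.isclose(kl_b, 1.0))                            # True
```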
Backward pass (by hand)
d_x_hat = (2/N) * (x_hat - x) # (N, d)
grad_w_dec = z.T @ d_x_hat # (d_z, d)
d_z = d_x_hat @ w_dec.T # (N, d_z); on the left, d_z is the gradient w.r.t. z, not the latent size
# Through reparameterization z = mu + sigma * eps:
d_mu_recon = d_z
d_log_var_recon = d_z * 0.5 * sigma * eps # dz/d_log_var = 0.5 * sigma * eps
# KL contributions:
d_mu_kl = mu / N # from -0.5 * (-2*mu) / N
d_log_var_kl = (exp(log_var) - 1) / (2*N) # from -0.5 * (1 - exp(log_var)) / N
d_mu = d_mu_recon + d_mu_kl
d_log_var = d_log_var_recon + d_log_var_kl
d_h = concat(d_mu, d_log_var, dim=-1) # (N, 2*d_z)
grad_w_enc = x.T @ d_h # (d, 2*d_z)
w_enc -= lr * grad_w_enc
w_dec -= lr * grad_w_dec
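One way to gain confidence in the hand-derived gradients is to compare them against central finite differences with eps held fixed. The sketch below does that in NumPy; the helper names (forward_loss, manual_grads, numeric_grad), the toy shapes, the step size, and the learning rate are all assumptions for illustration, not part of the problem statement.

```python
import numpy as np

def forward_loss(x, w_enc, w_dec, eps):
    """Forward pass and negative ELBO for the linear VAE."""
    n_latent = w_dec.shape[0]
    h = x @ w_enc
    mu, log_var = h[:, :n_latent], h[:, n_latent:]
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * eps
    x_hat = z @ w_dec
    recon = ((x_hat - x) ** 2).sum(axis=-1).mean()
    kl = (-0.5 * (1 + log_var - mu**2 - np.exp(log_var))).sum(axis=-1).mean()
    return recon + kl, (mu, log_var, sigma, z, x_hat)

def manual_grads(x, w_enc, w_dec, eps):
    """Chain-rule gradients, following the backward pass listed above."""
    N = x.shape[0]
    _, (mu, log_var, sigma, z, x_hat) = forward_loss(x, w_enc, w_dec, eps)
    d_x_hat = (2.0 / N) * (x_hat - x)
    grad_w_dec = z.T @ d_x_hat
    grad_z = d_x_hat @ w_dec.T                       # called d_z in the listing
    d_mu = grad_z + mu / N
    d_log_var = grad_z * 0.5 * sigma * eps + (np.exp(log_var) - 1.0) / (2.0 * N)
    d_h = np.concatenate([d_mu, d_log_var], axis=-1)
    grad_w_enc = x.T @ d_h
    return grad_w_enc, grad_w_dec

def numeric_grad(loss_fn, w, step=1e-5):
    """Central finite differences; perturbs w in place and restores it."""
    g = np.zeros_like(w)
    for idx in np.ndindex(*w.shape):
        old = w[idx]
        w[idx] = old + step
        up = loss_fn()
        w[idx] = old - step
        down = loss_fn()
        w[idx] = old
        g[idx] = (up - down) / (2.0 * step)
    return g

# Hypothetical toy sizes: N=5, d=4, d_z=3
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 4))
w_enc = 0.1 * rng.standard_normal((4, 6))
w_dec = 0.1 * rng.standard_normal((3, 4))
eps = rng.standard_normal((5, 3))

g_enc, g_dec = manual_grads(x, w_enc, w_dec, eps)
loss_fn = lambda: forward_loss(x, w_enc, w_dec, eps)[0]
num_enc = numeric_grad(loss_fn, w_enc)
num_dec = numeric_grad(loss_fn, w_dec)
print("max |manual - numeric| (enc):", np.abs(g_enc - num_enc).max())
print("max |manual - numeric| (dec):", np.abs(g_dec - num_dec).max())

# One SGD step with an assumed learning rate
lr = 1e-2
new_w_enc = w_enc - lr * g_enc
new_w_dec = w_dec - lr * g_dec
```

In float64 the two gradient estimates should agree to well below 1e-6; a larger gap usually points to a missing 1/N factor or a sign error in one of the KL terms.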
Inputs
- x: shape (N, d), input batch.
- w_enc: shape (d, 2*d_z), encoder weights; columns [:d_z] map to mu, [d_z:] to log_var.
- w_dec: shape (d_z, d), decoder weights.
- lr: float, SGD learning rate.
- seed: int, for reproducible reparameterization noise.
Output
Returns shape (d * 2*d_z + d_z * d,): the concatenation of (final_w_enc.flatten(), final_w_dec.flatten()).
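The packing of the two weight matrices into a single flat vector can be sketched as below. The toy sizes and placeholder weight values are assumptions chosen only to make the layout visible; the unpacking step is included to show that the layout is reversible.

```python
import numpy as np

d, d_z = 3, 2   # hypothetical sizes

# Placeholder "updated" weights, filled with distinct values
final_w_enc = np.arange(d * 2 * d_z, dtype=float).reshape(d, 2 * d_z)
final_w_dec = -np.arange(d_z * d, dtype=float).reshape(d_z, d)

# Encoder weights first, then decoder weights, both in row-major order
out = np.concatenate([final_w_enc.flatten(), final_w_dec.flatten()])
print(out.shape)  # (18,)

# Recover the matrices from the flat vector
n_enc = d * 2 * d_z
w_enc_back = out[:n_enc].reshape(d, 2 * d_z)
w_dec_back = out[n_enc:].reshape(d_z, d)
```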
PRNG note
PyTorch and JAX use different pseudo-random number generators. Given the same
seed, they produce different eps samples, so their updated weights diverge.
The expected outputs here are generated using PyTorch only (framework: :pytorch).
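The same effect can be seen without installing either framework. The sketch below is only an analogy: it compares two NumPy generators (the legacy Mersenne Twister RandomState and the PCG64-based default_rng), not PyTorch and JAX themselves, and the shape is a hypothetical (N, d_z). The same seed gives different noise across generators, while the same generator with the same seed is reproducible.

```python
import numpy as np

seed = 0
shape = (2, 3)  # hypothetical (N, d_z)

eps_mt = np.random.RandomState(seed).standard_normal(shape)    # legacy MT19937
eps_pcg = np.random.default_rng(seed).standard_normal(shape)   # PCG64
eps_pcg_again = np.random.default_rng(seed).standard_normal(shape)

print(np.allclose(eps_mt, eps_pcg))             # False: different generators
print(np.array_equal(eps_pcg, eps_pcg_again))   # True: same generator, same seed
```

For the same reason, noise drawn from NumPy will not reproduce the expected outputs of this problem, which depend on PyTorch's generator.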