Implement Transposed Convolution
Why this matters
Transposed convolution (a.k.a. “deconvolution” — though that’s a misnomer) is the dual of regular conv: where Conv2D downsamples (or preserves) spatial size, ConvTranspose upsamples. It’s the workhorse of:
- Image generation (GANs, VAEs, diffusion U-Nets): upsample latent maps to pixel space.
- Semantic segmentation (U-Net decoder): recover input resolution after a contractive encoder.
- Audio synthesis: WaveNet-style vocoders upsample frame-rate conditioning features to the audio sample rate.
Mathematically, transposed conv shares its weight tensor with regular conv but applies it in the “reverse” direction (the transpose of the matrix you’d write down for the regular conv).
What stride means here (different!)
For regular conv, stride=2 makes the output smaller (downsample).
For transposed conv, stride=2 makes the output larger (upsample by 2×).
With padding="SAME" and stride=s:
- Regular conv: H_out = ceil(H_in / s).
- Transposed conv: H_out = H_in * s.
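A quick shape check of the two rules above, sweeping a few input lengths and strides. The lengths, strides, and kernel width 3 are arbitrary illustrative choices:

```python
import math

import jax
import jax.numpy as jnp

DN = ("NWC", "WIO", "NWC")  # 1-D layouts: (batch, width, channels) / (width, in, out)
K = 3                       # kernel width (arbitrary for this check)

lengths = {}
for L in (5, 8):
    for s in (1, 2, 3):
        x = jnp.ones((1, L, 1))
        w = jnp.ones((K, 1, 1))
        # Regular conv: the stride moves the window, so the output shrinks.
        down = jax.lax.conv_general_dilated(
            x, w, window_strides=(s,), padding="SAME", dimension_numbers=DN)
        # Transposed conv: the stride spreads the input out, so the output grows.
        up = jax.lax.conv_transpose(
            x, w, strides=(s,), padding="SAME", dimension_numbers=DN)
        lengths[(L, s)] = (down.shape[1], up.shape[1])
        print(f"L={L}, s={s}: conv -> {down.shape[1]}, "
              f"conv_transpose -> {up.shape[1]}")
```

For L=5, s=2 the regular conv gives ceil(5/2) = 3, while the transposed conv gives 10.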
Common pattern in U-Nets: ConvTranspose with stride=2 doubles spatial size at each decoder stage.
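A minimal sketch of that decoder pattern in 2-D, using the HWIO kernel layout with ("NHWC", "HWIO", "NHWC"). The 4×4 bottleneck, the channel schedule, and the random kernels are illustrative assumptions; a real U-Net would also concatenate encoder skip connections at each stage, which is omitted here:

```python
import jax
import jax.numpy as jnp

DN2 = ("NHWC", "HWIO", "NHWC")  # 2-D layouts

key = jax.random.PRNGKey(0)
h = jnp.ones((1, 4, 4, 64))     # hypothetical bottleneck feature map (N, H, W, C)
channels = [64, 32, 16, 8]      # hypothetical channel schedule per decoder stage

shapes = []
for c_in, c_out in zip(channels[:-1], channels[1:]):
    key, sub = jax.random.split(key)
    kernel = 0.1 * jax.random.normal(sub, (3, 3, c_in, c_out))  # HWIO
    # Stride 2 in both spatial dims doubles H and W.
    h = jax.lax.conv_transpose(
        h, kernel, strides=(2, 2), padding="SAME", dimension_numbers=DN2)
    shapes.append(h.shape)
    print(h.shape)
# Prints (1, 8, 8, 32), then (1, 16, 16, 16), then (1, 32, 32, 8).
```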
API: jax.lax.conv_transpose
JAX has a dedicated primitive: jax.lax.conv_transpose.
y = jax.lax.conv_transpose(
x[None, ...], # add batch dim
kernel, # WIO layout (1-D) or HWIO (2-D)
strides=(stride,),
padding="SAME",
dimension_numbers=("NWC", "WIO", "NWC"),
)
Note the keyword is strides= (not window_strides=, as in conv_general_dilated). Annoying, but the different name fits: conv_transpose applies the stride to its input (inserting zeros between samples), not to the sliding window.
Worked 1-D example
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])[:, None]  # (4, 1): length 4, one channel
kernel = jnp.ones((3, 1, 1))              # WIO: w=3, in=1, out=1
y = jax.lax.conv_transpose(
    x[None, ...],                         # add batch dim -> (1, 4, 1)
    kernel,
    strides=(2,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
# y.shape == (1, 8, 1): input length 4, stride 2 -> output length 8.
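To see what strides= actually does, this sketch recomputes the worked example as a stride-1 conv over a dilated input: conv_general_dilated's lhs_dilation=(2,) inserts one zero between neighbouring samples. The explicit padding [(2, 1)] is an assumption about how JAX splits SAME padding for kernel_size=3, stride=2 (k + s - 2 = 3 padded elements in total, with the extra one in front); the allclose check is what confirms it.

```python
import jax
import jax.numpy as jnp

DN = ("NWC", "WIO", "NWC")
x = jnp.array([1., 2., 3., 4.])[None, :, None]  # (1, 4, 1): same data as the worked example
kernel = jnp.ones((3, 1, 1))                      # WIO: w=3, in=1, out=1

y = jax.lax.conv_transpose(
    x, kernel, strides=(2,), padding="SAME", dimension_numbers=DN)

# Stride-1 conv over the zero-interleaved input [1, 0, 2, 0, 3, 0, 4],
# padded with 2 zeros in front and 1 behind (assumed SAME split).
y_dilated = jax.lax.conv_general_dilated(
    x, kernel, window_strides=(1,), padding=[(2, 1)],
    lhs_dilation=(2,), dimension_numbers=DN)

print(y[0, :, 0])
print(bool(jnp.allclose(y, y_dilated)))
```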
Why “transposed” and not “deconvolution”?
A true deconvolution would invert the convolution operation, which means solving a linear system. Transposed conv only multiplies by the transpose of the forward conv’s weight matrix; it does NOT invert it. The output is a new feature map, not the original input recovered.
The “transposed” name comes from the matrix view: if you express conv as y = W @ x, with W a sparse, banded matrix built from the kernel (Toeplitz for a stride-1, single-channel conv), then transposed conv is x' = W.T @ y: same W, transposed.
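The matrix view can be checked numerically. Because a conv is linear in its input, jax.jacfwd materializes W for a small forward conv, and W.T @ y can then be compared with jax.lax.conv_transpose. The sketch passes transpose_kernel=True, which makes JAX flip the kernel's spatial axis and swap its I/O axes so the forward kernel can be reused as-is; the sizes are illustrative.

```python
import jax
import jax.numpy as jnp

DN = ("NWC", "WIO", "NWC")
L, C_IN, C_OUT, K, S = 8, 2, 3, 3, 2  # illustrative sizes

kw, ky = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (K, C_IN, C_OUT))    # forward-conv kernel (WIO)
y = jax.random.normal(ky, (1, L // S, C_OUT))  # a map in the conv's output space

def conv_flat(v):
    # Forward strided conv as a function of the flattened input:
    # R^(L*C_IN) -> R^((L//S)*C_OUT).
    out = jax.lax.conv_general_dilated(
        v.reshape(1, L, C_IN), w, window_strides=(S,), padding="SAME",
        dimension_numbers=DN)
    return out.reshape(-1)

# The conv is linear, so its Jacobian (taken anywhere) is the matrix W.
W = jax.jacfwd(conv_flat)(jnp.zeros(L * C_IN))  # shape (12, 16), mostly zeros

# Matrix view of the transposed conv: x' = W.T @ y.
x_matrix = (W.T @ y.reshape(-1)).reshape(1, L, C_IN)

# The same map via conv_transpose; transpose_kernel=True flips the spatial
# axis and swaps I/O so the forward kernel w can be passed unchanged.
x_op = jax.lax.conv_transpose(
    y, w, strides=(S,), padding="SAME", dimension_numbers=DN,
    transpose_kernel=True)

print(W.shape, x_op.shape)                            # (12, 16) (1, 8, 2)
print(bool(jnp.allclose(x_matrix, x_op, atol=1e-5)))
```

Without transpose_kernel=True, conv_transpose reads the kernel's I axis as its own input channels, so this forward kernel would not line up with y.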
Common pitfalls
- Wrong kwarg name: strides= for conv_transpose, window_strides= for conv_general_dilated. Mix them up and you’ll get a TypeError.
- Expecting “true deconvolution”: transposed conv is NOT an inverse. Don’t expect conv(transpose_conv(x)) == x.
- Same kernel layout (WIO/HWIO), but the interpretation flips: in transposed conv, the I dim is the input to the transposed op (= output of the forward conv) and O is its output. With the default transpose_kernel=False, JAX reads kernel[..., I, O] in exactly that order, so you just need your kernel to match the layout. Passing transpose_kernel=True instead flips the spatial axes and swaps I and O, which lets you reuse a forward-conv kernel directly.
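The first two pitfalls are easy to check directly. This sketch uses a random kernel (an arbitrary choice): the wrong keyword raises a TypeError, and a forward conv applied after a transposed conv with the same kernel returns the original shape but not the original values.

```python
import jax
import jax.numpy as jnp

DN = ("NWC", "WIO", "NWC")
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1, 4, 1))
w = jax.random.normal(kw, (3, 1, 1))  # WIO

# Pitfall 1: conv_transpose takes strides=, not window_strides=.
try:
    jax.lax.conv_transpose(x, w, window_strides=(2,), padding="SAME",
                           dimension_numbers=DN)
    raised = False
except TypeError:
    raised = True
print(raised)  # True

# Pitfall 2: conv(conv_transpose(x)) is not x.
up = jax.lax.conv_transpose(x, w, strides=(2,), padding="SAME",
                            dimension_numbers=DN)             # (1, 8, 1)
back = jax.lax.conv_general_dilated(up, w, window_strides=(2,), padding="SAME",
                                    dimension_numbers=DN)     # (1, 4, 1)
print(back.shape == x.shape)        # True: the shape round-trips
print(bool(jnp.allclose(back, x)))  # the values do not
```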
Problem
Implement MyTransposedConv(features, kernel_size, stride) for 1-D inputs:
- Kernel shape (kernel_size, x.shape[-1], features), init lecun_normal().
- Bias shape (features,), init zeros.
- Add batch dim, call jax.lax.conv_transpose with strides=(stride,), padding="SAME", dimension_numbers=("NWC", "WIO", "NWC").
- Drop batch dim, add bias.
- Return .reshape(-1).
Inputs:
- seed, features, kernel_size, stride: floats (cast to int).
- x: 2-D, shape (L, C_in).
Output: 1-D, the (L * stride, features) result flattened.
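A hedged sketch of the steps above, written as a plain function rather than a module. The name my_transposed_conv and its argument order are hypothetical; the platform's reference presumably uses a Flax module (MyTransposedConv). This version draws its parameters with jax.nn.initializers (lecun_normal, zeros), so the kernel need not match the reference's random draws for the same seed.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

def my_transposed_conv(seed, features, kernel_size, stride, x):
    # Hypothetical signature; the inputs arrive as floats, so cast to int.
    seed, features = int(seed), int(features)
    kernel_size, stride = int(kernel_size), int(stride)
    x = jnp.asarray(x)                        # (L, C_in)

    k_key, b_key = jax.random.split(jax.random.PRNGKey(seed))
    # Kernel (kernel_size, C_in, features) in WIO layout, LeCun-normal init.
    kernel = initializers.lecun_normal()(k_key, (kernel_size, x.shape[-1], features))
    # Bias (features,), zero init.
    bias = initializers.zeros(b_key, (features,))

    y = jax.lax.conv_transpose(
        x[None, ...],                         # add batch dim -> (1, L, C_in)
        kernel,
        strides=(stride,),
        padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )                                         # (1, L * stride, features)
    y = y[0] + bias                           # drop batch dim, add bias
    return y.reshape(-1)                      # flatten to 1-D

out = my_transposed_conv(0.0, 4.0, 3.0, 2.0, jnp.ones((5, 2)))
print(out.shape)  # (40,)
```

With L=5, stride=2, and features=4, the flattened output has 5 * 2 * 4 = 40 entries.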