
Implement Transposed Convolution

Why this matters

Transposed convolution (often called “deconvolution”, which is a misnomer) is the counterpart of regular convolution: where a strided Conv2D downsamples spatial size (or preserves it at stride 1), ConvTranspose upsamples. It’s the workhorse of:

  • Image generation (GANs, VAEs, diffusion U-Nets): upsample latent maps to pixel space.
  • Semantic segmentation (U-Net decoder): recover input resolution after a contractive encoder.
  • Audio synthesis: WaveNet-like architectures.

Mathematically, a transposed conv can use the same weight tensor as a regular conv but applies it in the “reverse” direction: it multiplies by the transpose of the matrix you’d write down for the regular conv.

What stride means here (different!)

For regular conv, stride=2 makes the output smaller (downsample). For transposed conv, stride=2 makes the output larger (upsample by 2×).

With padding="SAME" and stride=s:

  • Regular conv: H_out = ceil(H_in / s).
  • Transposed conv: H_out = H_in * s.

Common pattern in U-Nets: ConvTranspose with stride=2 doubles spatial size at each decoder stage.
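The two shape rules can be checked directly. The sketch below is illustrative only: the input length of 5 and the all-ones kernel are arbitrary choices, not part of the exercise. It also shows the two different stride keyword names side by side.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 5, 1))        # NWC: batch=1, length=5, channels=1
kernel = jnp.ones((3, 1, 1))   # WIO: width=3, in=1, out=1
dn = ("NWC", "WIO", "NWC")

# Regular conv: stride is passed as window_strides=.
down = jax.lax.conv_general_dilated(
    x, kernel, window_strides=(2,), padding="SAME", dimension_numbers=dn)

# Transposed conv: stride is passed as strides=.
up = jax.lax.conv_transpose(
    x, kernel, strides=(2,), padding="SAME", dimension_numbers=dn)

print(down.shape)  # (1, 3, 1): ceil(5 / 2) = 3
print(up.shape)    # (1, 10, 1): 5 * 2 = 10
```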

API: jax.lax.conv_transpose

JAX provides a dedicated function for this: jax.lax.conv_transpose (a wrapper around conv_general_dilated).

y = jax.lax.conv_transpose(
    x[None, ...],                           # add batch dim
    kernel,                                 # WIO layout (1-D) or HWIO (2-D)
    strides=(stride,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)

Note the kwarg name is strides= (not window_strides=), which differs from conv_general_dilated. The two functions are not consistent with each other here, so check the name whenever you switch between them.

Worked 1-D example

import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])[:, None]   # (4, 1)
kernel = jnp.ones((3, 1, 1))               # WIO: w=3, in=1, out=1

y = jax.lax.conv_transpose(
    x[None, ...],
    kernel,
    strides=(2,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
# y shape: (1, 8, 1) — input length 4, stride 2 → output length 8.

Why “transposed” and not “deconvolution”?

A true deconvolution would invert the convolution operation, which means solving a linear system. Transposed conv multiplies by the transpose of the forward conv’s weight matrix; it does NOT invert that matrix. The output is a new feature map, not the original input recovered.

The “transposed” name comes from the matrix view: if you express conv as y = W @ x with W a sparse Toeplitz matrix, then transposed conv is x' = W.T @ y — same W, transposed.
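The matrix view can be made concrete with a small sketch. It assumes a strided VALID conv and a hand-picked kernel (both arbitrary choices for illustration), recovers W as the Jacobian of the conv, and compares W.T @ y against the conv's vector-Jacobian product. The helper name conv1d is hypothetical.

```python
import jax
import jax.numpy as jnp

def conv1d(x, k, stride=2):
    """Strided VALID 1-D conv of a length-L vector x with a length-K kernel k."""
    y = jax.lax.conv_general_dilated(
        x[None, :, None],            # (1, L, 1) in NWC
        k[:, None, None],            # (K, 1, 1) in WIO
        window_strides=(stride,),
        padding="VALID",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
    return y[0, :, 0]

k = jnp.array([1., 2., 3.])
x = jnp.arange(1., 8.)               # length 7

# Conv is linear in x, so its Jacobian is exactly the matrix W.
W = jax.jacobian(lambda v: conv1d(v, k))(x)   # shape (3, 7)
y = W @ x                                     # forward conv, length 3
x_up = W.T @ y                                # transposed conv, length 7

# The VJP of the forward conv computes the same W.T @ y product.
_, vjp_fn = jax.vjp(lambda v: conv1d(v, k), x)
(x_vjp,) = vjp_fn(y)

print(W.shape, y.shape, x_up.shape)  # (3, 7) (3,) (7,)
print(jnp.allclose(x_up, x_vjp))     # same linear map
print(jnp.allclose(x_up, x))         # compares against the original input
```

The last comparison is the point of the section: x_up has the same length as x, but its values differ, because multiplying by W.T is not the same as inverting W.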

Common pitfalls

  • Wrong kwarg name: strides= for conv_transpose, window_strides= for conv_general_dilated. Mix them up and you’ll get a TypeError.
  • Expecting “true deconvolution”: transposed conv is NOT an inverse. Don’t expect conv(transpose_conv(x)) == x.
  • Same kernel layout (WIO/HWIO), but the interpretation flips: in transposed conv, I is the channel count of the input to the transposed op (= output of the forward conv) and O is the channel count of its output. With the default transpose_kernel=False, jax.lax.conv_transpose reads the kernel this way directly, so kernel.shape[-2] must equal x.shape[-1].
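A shape-only sketch of the kernel-layout pitfall, assuming 3 input channels and 5 output channels (arbitrary values). It contrasts the default call with transpose_kernel=True, where jax.lax.conv_transpose flips the spatial axis and swaps the kernel's I and O axes so that you can pass the forward conv's kernel unchanged.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 4, 3))             # NWC: 3 input channels
dn = ("NWC", "WIO", "NWC")

# Default (transpose_kernel=False): I matches x's channels, O is the output.
k_direct = jnp.ones((3, 3, 5))      # (W, I=3, O=5)
y_direct = jax.lax.conv_transpose(
    x, k_direct, strides=(2,), padding="SAME", dimension_numbers=dn)

# transpose_kernel=True: pass the forward conv's kernel, whose O axis (3)
# matches x's channels and whose I axis (5) becomes the output channels.
k_forward = jnp.ones((3, 5, 3))     # (W, I=5, O=3) of the forward conv
y_forward = jax.lax.conv_transpose(
    x, k_forward, strides=(2,), padding="SAME", dimension_numbers=dn,
    transpose_kernel=True)

print(y_direct.shape, y_forward.shape)  # both (1, 8, 5)
```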

Problem

Implement MyTransposedConv(features, kernel_size, stride) for 1-D inputs:

  1. Kernel shape (kernel_size, x.shape[-1], features), init lecun_normal().
  2. Bias shape (features,), init zeros.
  3. Add batch dim, call jax.lax.conv_transpose with strides=(stride,), padding="SAME", dimension_numbers=("NWC", "WIO", "NWC").
  4. Drop batch dim, add bias.
  5. Return the result flattened with .reshape(-1).

Inputs:

  • seed, features, kernel_size, stride: passed as floats; cast each to int before use.
  • x: 2-D (L, C_in).

Output: 1-D flattened.
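One way to sketch the five steps is shown below in plain JAX rather than as a Flax module. This is an assumption-laden illustration, not the reference solution: the function name my_transposed_conv is hypothetical, and deriving the key with jax.random.PRNGKey(seed) is a guess at how the seed is meant to be used, so the numeric values may differ from what the grader expects.

```python
import jax
import jax.numpy as jnp

def my_transposed_conv(seed, features, kernel_size, stride, x):
    # Inputs arrive as floats; cast to int before use.
    seed, features = int(seed), int(features)
    kernel_size, stride = int(kernel_size), int(stride)

    # Step 1: kernel of shape (kernel_size, C_in, features), lecun_normal init.
    key = jax.random.PRNGKey(seed)  # assumed seeding scheme
    kernel = jax.nn.initializers.lecun_normal()(
        key, (kernel_size, x.shape[-1], features))

    # Step 2: zero bias of shape (features,).
    bias = jnp.zeros((features,))

    # Step 3: add batch dim and upsample.
    y = jax.lax.conv_transpose(
        x[None, ...],
        kernel,
        strides=(stride,),
        padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )

    # Steps 4 and 5: drop batch dim, add bias, flatten.
    return (y[0] + bias).reshape(-1)

x = jnp.ones((4, 3))                       # (L=4, C_in=3)
out = my_transposed_conv(0., 5., 3., 2., x)
print(out.shape)                           # (40,): L * stride * features = 4 * 2 * 5
```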

Hints

flax conv-transpose upsample
