medium primitives

Data Collator with Padding

Implement batched data collation with padding: the preprocessing step that converts a list of variable-length integer sequences into a uniform 2-D batch ready for a model's forward pass.

Why padding is needed:

Neural network operations (matrix multiply, attention, convolutions) require every sample in a batch to have the same shape. Sequences in the wild have different lengths, so we find the longest sequence in the batch and right-pad all shorter sequences up to that length with a sentinel pad_id.
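To make the idea concrete, here is the padding arithmetic for a small, made-up batch of three sequences. The batch values are illustrative only.

```python
# Worked example: how much padding each sequence in a batch needs.
batch = [[5, 3, 8], [2], [7, 1]]

max_len = max(len(s) for s in batch)            # longest sequence has 3 tokens
pad_counts = [max_len - len(s) for s in batch]  # pad tokens needed per sequence

print(max_len)     # 3
print(pad_counts)  # [0, 2, 1]
```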

Right-padding vs left-padding:

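Right-padding is the standard choice for training forward passes on both RNNs and transformers. With an RNN, the model consumes all real tokens first and meets padding only at the end, where it can be masked out. With a causal transformer, pad tokens sit after every real token, so the causal mask already stops real tokens from attending to them; an additive padding mask (a large negative value added to the attention logits at pad positions) drives any remaining attention weight on pads to zero after the softmax. Left-padding is used for batched generation, where every prompt must end at the same position so that each new token is appended at the same index across the batch (as some inference engines expect). For training, right-pad.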
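The two layouts are easiest to compare side by side. The sketch below uses a hypothetical helper, `pad_with_mask`, with a `side` argument; neither is part of the required `pad_batch` signature. It also returns a padding mask under an assumed convention of 1 for a real token and 0 for a pad token.

```python
def pad_with_mask(sequences, pad_id=0, side="right"):
    """Hypothetical helper: pad to the batch max length, return (padded, mask).

    mask[i][j] is 1 for a real token and 0 for a pad token (assumed convention).
    side="right" suits training; side="left" suits batched generation.
    """
    if side not in ("right", "left"):
        raise ValueError(f"side must be 'right' or 'left', got {side!r}")
    if not sequences:
        return [], []
    max_len = max(len(s) for s in sequences)
    padded, mask = [], []
    for s in sequences:
        n_pad = max_len - len(s)
        pads, real = [pad_id] * n_pad, list(s)
        if side == "right":
            padded.append(real + pads)                   # real tokens first
            mask.append([1] * len(s) + [0] * n_pad)
        else:
            padded.append(pads + real)                   # real tokens last
            mask.append([0] * n_pad + [1] * len(s))
    return padded, mask


batch = [[5, 3, 8], [2]]
print(pad_with_mask(batch))               # ([[5, 3, 8], [2, 0, 0]], [[1, 1, 1], [1, 0, 0]])
print(pad_with_mask(batch, side="left"))  # ([[5, 3, 8], [0, 0, 2]], [[1, 1, 1], [0, 0, 1]])
```

With `side="left"`, the last real token of every row lands in the final column, which is why batched generation prefers that layout.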

Algorithm:

  1. Guard: if sequences is empty, return [] immediately. (Calling max on an empty iterable raises ValueError.)
  2. Find max_len = max(len(s) for s in sequences).
  3. For each sequence s, concatenate [pad_id] * (max_len - len(s)) onto a copy of s to bring it to max_len. If s is already the longest, the padding list is empty and nothing is added.
  4. Return the list of padded sequences.
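The four steps above can be sketched directly in Python as follows. This is one straightforward way to satisfy the stated signature without `pad_sequence`, not the site's reference solution; building new lists (rather than extending the inputs in place) is a design choice so the caller's data is never mutated.

```python
def pad_batch(sequences, pad_id=0) -> list[list[int]]:
    """Right-pad variable-length integer sequences to the batch's max length."""
    # Step 1: guard against an empty batch (max() of an empty iterable raises ValueError).
    if not sequences:
        return []

    # Step 2: the target length is the longest sequence in the batch.
    max_len = max(len(s) for s in sequences)

    # Steps 3-4: copy each sequence and add the pad tokens it is missing.
    padded = []
    for s in sequences:
        padded.append(list(s) + [pad_id] * (max_len - len(s)))
    return padded


print(pad_batch([[5, 3, 8], [2], [7, 1]]))  # [[5, 3, 8], [2, 0, 0], [7, 1, 0]]
print(pad_batch([[1], [2, 3]], pad_id=-1))  # [[1, -1], [2, 3]]
```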

Edge cases:

  • max_len = 0 when every sequence is empty: s + [pad_id] * 0 = s + [] = [], so each sequence remains []. Correct.
  • A single sequence: no padding needed, returned as-is.
  • All sequences the same length: no padding added.
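The edge cases can be checked directly. This self-contained sketch uses a compact list-comprehension form of the same algorithm (an equivalent rewrite, assumed for illustration) and asserts each case listed above.

```python
def pad_batch(sequences, pad_id=0):
    # Compact form of the same algorithm: guard, compute max_len, right-pad copies.
    if not sequences:
        return []
    max_len = max(len(s) for s in sequences)
    return [list(s) + [pad_id] * (max_len - len(s)) for s in sequences]


# Empty batch: the guard returns [] before max() is ever called.
assert pad_batch([]) == []
# Every sequence empty: max_len == 0, so [pad_id] * 0 == [] and each row stays [].
assert pad_batch([[], []]) == [[], []]
# A single sequence is already the longest, so nothing is added.
assert pad_batch([[4, 2]]) == [[4, 2]]
# Equal lengths: no padding is added.
assert pad_batch([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
```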

Return type โ€” list-of-list, not tensor:

Returning list[list[int]] keeps the helper framework-agnostic and avoids committing to a specific dtype. If you need a tensor, call torch.tensor(pad_batch(seqs)) or jnp.array(pad_batch(seqs)) yourself.

Function signature:

def pad_batch(sequences, pad_id=0) -> list[list[int]]

Do NOT use torch.nn.utils.rnn.pad_sequence โ€” build the output explicitly.

Hints

data padding preprocessing
