
Bidirectional RNN

Why this matters

A vanilla RNN consumes a sequence left to right. At time t the hidden state has only seen x_1 ... x_t; the future is invisible. For tasks where each output position should depend on the WHOLE sequence (named entity recognition, POS tagging, and, before Transformers, bidirectional language-model encoders such as ELMo), this is a problem.
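A quick way to see the limitation: change only the future inputs and check that the earlier hidden states do not move. This is a minimal NumPy sketch that assumes a plain tanh RNN cell, h_t = tanh(x_t W_x + h_{t-1} W_h + b), as a stand-in for an LSTM/GRU; the helper name `rnn_left_to_right` is hypothetical.

```python
import numpy as np

def rnn_left_to_right(x, W_x, W_h, b):
    """Plain tanh RNN (illustrative stand-in for an LSTM/GRU). x: (T, D) -> (T, H)."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in x:                                # consume positions 0, 1, ..., T-1 in order
        h = np.tanh(x_t @ W_x + h @ W_h + b)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
T, D, H = 5, 3, 4
W_x, W_h, b = rng.normal(size=(D, H)), rng.normal(size=(H, H)), np.zeros(H)

x = rng.normal(size=(T, D))
x_new_future = x.copy()
x_new_future[3:] += 10.0                         # change ONLY positions 3 and 4

h_a = rnn_left_to_right(x, W_x, W_h, b)
h_b = rnn_left_to_right(x_new_future, W_x, W_h, b)
print(np.array_equal(h_a[:3], h_b[:3]))          # True: states 0..2 never saw x[3:]
```

Whatever happens at positions 3 and 4, the states at positions 0 to 2 are bit-for-bit identical, which is exactly the "future is invisible" property.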

The fix is the bidirectional RNN: run TWO independent RNNs, one forward, one backward, and concatenate their per-step outputs. Each position now has both “everything that came before” and “everything that comes after” baked into its representation.

BiLSTMs and BiGRUs were the workhorse for sequence-labeling tasks until Transformers (which see the whole sequence by default) took over. They still appear in:

  • Speech recognition (look-ahead matters).
  • Off-line tagging where the whole sentence is available.
  • Encoder side of seq2seq models when latency isn’t critical.

Architecture

Given input x of shape (T, D) and hidden size H:

fwd[t] = GRU_fwd(x[1..t])           # left-to-right state after reading x[1], ..., x[t]
bwd[t] = GRU_bwd(x[T..t])           # right-to-left state after reading x[T], ..., x[t]
out[t] = concat(fwd[t], bwd[t])     # length 2H

Each timestep gets a 2H-dimensional representation. The forward and backward RNNs have INDEPENDENT parameters: the backward RNN learns its own gates on a reversed view of the sequence.
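The architecture can be sketched directly, without any reversal trick, by running one RNN left to right and a second RNN right to left while writing each backward state back at its original position. This is a minimal NumPy sketch under stated assumptions: a plain tanh cell stands in for the GRU, positions are 0-indexed (so the pseudocode's x[T..t] becomes x[t], ..., x[T-1]), and `init_params`, `step`, `run_fwd`, and `run_bwd` are hypothetical helper names.

```python
import numpy as np

def init_params(rng, D, H):
    """One independent (W_x, W_h, b) set; call once per direction."""
    return rng.normal(size=(D, H)), rng.normal(size=(H, H)), np.zeros(H)

def step(x_t, h, params):
    W_x, W_h, b = params
    return np.tanh(x_t @ W_x + h @ W_h + b)      # tanh cell as a stand-in for a GRU cell

def run_fwd(x, params):
    h, out = np.zeros(params[1].shape[0]), []
    for t in range(len(x)):                      # t = 0, 1, ..., T-1
        h = step(x[t], h, params)
        out.append(h)                            # out[t] summarizes x[0..t]
    return np.stack(out)

def run_bwd(x, params):
    h, out = np.zeros(params[1].shape[0]), [None] * len(x)
    for t in reversed(range(len(x))):            # t = T-1, ..., 1, 0
        h = step(x[t], h, params)
        out[t] = h                               # stored at position t; summarizes x[t..T-1]
    return np.stack(out)

rng = np.random.default_rng(0)
T, D, H = 6, 3, 4
params_f = init_params(rng, D, H)                # forward weights
params_b = init_params(rng, D, H)                # separate backward weights
x = rng.normal(size=(T, D))

fwd, bwd = run_fwd(x, params_f), run_bwd(x, params_b)
out = np.concatenate([fwd, bwd], axis=-1)
print(out.shape)                                 # (6, 8): each position gets 2H features
```

Note that the backward state at the last position has seen only x[T-1], mirroring how the forward state at position 0 has seen only x[0].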

Reversing in JAX

The “backward RNN” is the same kind of RNN (with its own weights) run on a REVERSED input; its outputs are then reversed again so they line up with the original positions:

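import jax.numpy as jnp

# rnn_fwd, rnn_bwd: two separate nn.RNN(nn.GRUCell(features=H)) modules
x_b   = x[None, ...]                              # add batch dim: (1, T, D)
fwd_h = rnn_fwd(x_b)                              # (1, T, H)
bwd_h = rnn_bwd(x_b[:, ::-1, :])[:, ::-1, :]      # (1, T, H), re-aligned to original positions
out   = jnp.concatenate([fwd_h, bwd_h], axis=-1)  # (1, T, 2H)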

Why two reverses? After running on the reversed input, the output is itself in reverse order: its first step corresponds to the LAST input position. Reversing it back makes position t of the output match position t of the input.
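One way to check the double reverse is to compare it with an explicit right-to-left loop. To keep the numbers easy to follow, this NumPy sketch uses a deliberately trivial toy "RNN" whose state is a running sum (h_t = h_{t-1} + x_t); that toy cell and the name `running_sum_rnn` are illustrative assumptions, not Flax APIs. With a running sum, the backward state at position t should be the suffix sum x[t] + ... + x[T-1].

```python
import numpy as np

def running_sum_rnn(x):
    """Toy left-to-right 'RNN' with h_t = h_{t-1} + x_t and h_{-1} = 0."""
    return np.cumsum(x, axis=0)

x = np.array([[1.0], [2.0], [3.0], [4.0]])       # (T=4, D=1)

flipped = running_sum_rnn(x[::-1])               # index 0 now refers to ORIGINAL position T-1
aligned = flipped[::-1]                          # reverse back: index t refers to position t

# Reference: an explicit right-to-left loop that stores each state at its own position.
reference = np.zeros_like(x)
h = np.zeros(x.shape[1])
for t in reversed(range(len(x))):
    h = h + x[t]
    reference[t] = h

print(flipped[:, 0].tolist())                    # [4.0, 7.0, 9.0, 10.0]  still in reverse order
print(aligned[:, 0].tolist())                    # [10.0, 9.0, 7.0, 4.0]  suffix sums, aligned
print(np.array_equal(aligned, reference))        # True
```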

Flax’s nn.RNN

flax.linen.RNN(cell, ...) wraps a single-step cell (e.g., nn.GRUCell(features=H)) into a sequence-level layer that loops over the time axis automatically. With the default time_major=False it expects inputs shaped (batch, time, features), so add a leading BATCH axis with x[None, ...] before calling.

Each nn.RNN(nn.GRUCell(features=H)) instance has its own parameters, so create TWO of them: one for the forward direction and one for the backward direction.

Worked walk-through

T=3, D=2, H=4, x=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]:

  1. x_b = x[None, ...] → (1, 3, 2), batch axis added.
  2. fwd = rnn_fwd(x_b) → (1, 3, 4).
  3. rev_in = x_b[:, ::-1, :] → still (1, 3, 2), but the time order is flipped: [[0.5, 0.6], [0.3, 0.4], [0.1, 0.2]].
  4. rev_out = rnn_bwd(rev_in) → (1, 3, 4).
  5. bwd = rev_out[:, ::-1, :] → (1, 3, 4), re-aligned to the original positions.
  6. concat([fwd, bwd], axis=-1) → (1, 3, 8).
  7. Drop the batch axis → (3, 8) → flatten to (24,).

Common pitfalls

  • Sharing one RNN for both directions: the forward and backward RNNs are SEPARATE modules with separate weights.
  • Forgetting to re-reverse the backward output: the backward features end up misaligned with their positions, and because the shape is unchanged nothing raises an error. The result is silently wrong.
  • Concatenating along the WRONG axis: must be channel / feature axis (axis=-1), not time.
  • Forgetting the batch axis: nn.RNN expects (B, T, D). x[None, ...] adds it.
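Two of these pitfalls are easy to reproduce with numbers you can check by hand. This NumPy sketch reuses a toy running-sum "RNN" (h_t = h_{t-1} + x_t), a hypothetical stand-in chosen so prefix and suffix sums are obvious; it is not how a GRU behaves, only a way to make alignment visible.

```python
import numpy as np

def running_sum_rnn(x):
    """Toy left-to-right 'RNN': h_t = h_{t-1} + x_t (illustrative only)."""
    return np.cumsum(x, axis=0)

x = np.array([[1.0], [2.0], [3.0]])                 # (T=3, D=1)
fwd = running_sum_rnn(x)                            # prefix sums: [1, 3, 6]
bwd_ok = running_sum_rnn(x[::-1])[::-1]             # suffix sums: [6, 5, 3]
bwd_bad = running_sum_rnn(x[::-1])                  # forgot to re-reverse: [3, 5, 6]

# Pitfall: same shape as the correct result, so nothing crashes...
print(bwd_bad.shape == bwd_ok.shape)                # True
# ...but position 0 now carries the state that belongs to position T-1.
print(bwd_bad[0, 0], bwd_ok[-1, 0])                 # 3.0 3.0

# Pitfall: concatenating along time instead of features.
print(np.concatenate([fwd, bwd_ok], axis=0).shape)  # (6, 1): 2T positions, wrong
print(np.concatenate([fwd, bwd_ok], axis=-1).shape) # (3, 2): T positions, 2H features
```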

Problem

Implement birnn_forward(seed, x, hidden):

  1. Define BiRNN(nn.Module) with a hidden field.
  2. Inside it, build two nn.RNN(nn.GRUCell(features=hidden)) instances.
  3. Run the forward RNN on x[None, ...]. Run the backward RNN on the time-reversed input, then reverse its output again.
  4. Concatenate along axis=-1, drop the batch dim, and flatten.

Inputs:

  • seed: int.
  • x: 2-D (T, D).
  • hidden: int H.

Output: 1-D, length T * 2H.

