Bidirectional RNN
Why this matters
A vanilla RNN consumes a sequence left-to-right. At time t the
hidden state has only seen x_1 ... x_t — the future is invisible.
For tasks where you want each output position to depend on the
WHOLE sequence (named entity recognition, POS tagging, masked
language modeling pre-Transformer), this is a problem: the output
at t cannot use x_{t+1} ... x_T.
The fix is the bidirectional RNN: run TWO independent RNNs, one forward, one backward, and concatenate their per-step outputs. Each position now has both “everything that came before” and “everything that comes after” baked into its representation.
BiLSTMs and BiGRUs were the workhorse for sequence-labeling tasks until Transformers (which see the whole sequence by default) took over. They still appear in:
- Speech recognition (look-ahead matters).
- Off-line tagging where the whole sentence is available.
- Encoder side of seq2seq models when latency isn’t critical.
Architecture
Given input x of shape (T, D):
fwd[t] = GRU_fwd(x[1..t]) # left-to-right hidden state after reading x[1] .. x[t]
bwd[t] = GRU_bwd(x[T..t]) # right-to-left hidden state after reading x[T] .. x[t]
out[t] = concat(fwd[t], bwd[t]) # length 2H
Each timestep gets a 2H representation. The forward and
backward RNNs have INDEPENDENT parameters — the backward RNN
learns its own gates on a reversed view of the sequence.
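To make the wiring concrete, here is a minimal from-scratch sketch in plain JAX. It is an illustration under stated assumptions, not the Flax implementation: it swaps the GRU for a plain tanh RNN cell to stay short, and the helper names (init_rnn, run_rnn, birnn) are hypothetical. Each direction gets its own parameter dict, and jax.lax.scan does the left-to-right sweep.

```python
import jax
import jax.numpy as jnp

def init_rnn(key, d_in, hidden):
    """Hypothetical helper: params for a plain tanh RNN cell (stand-in for a GRU)."""
    k1, k2 = jax.random.split(key)
    return {
        "W": jax.random.normal(k1, (d_in, hidden)) * 0.5,   # input -> hidden
        "U": jax.random.normal(k2, (hidden, hidden)) * 0.5, # hidden -> hidden
        "b": jnp.zeros((hidden,)),
    }

def run_rnn(params, x):
    """Scan the cell left-to-right over x of shape (T, D); returns (T, H)."""
    def step(h, x_t):
        h_new = jnp.tanh(x_t @ params["W"] + h @ params["U"] + params["b"])
        return h_new, h_new  # (new carry, per-step output)
    h0 = jnp.zeros((params["U"].shape[0],))
    _, hs = jax.lax.scan(step, h0, x)
    return hs

def birnn(params_fwd, params_bwd, x):
    fwd = run_rnn(params_fwd, x)              # fwd[t] depends on x[0..t]
    bwd = run_rnn(params_bwd, x[::-1])[::-1]  # bwd[t] depends on x[t..T-1]
    return jnp.concatenate([fwd, bwd], axis=-1)  # (T, 2H)

T, D, H = 5, 3, 4
kf, kb = jax.random.split(jax.random.PRNGKey(0))
p_fwd = init_rnn(kf, D, H)  # independent parameters ...
p_bwd = init_rnn(kb, D, H)  # ... for each direction
x = jax.random.normal(jax.random.PRNGKey(1), (T, D))
out = birnn(p_fwd, p_bwd, x)
print(out.shape)  # (5, 8)
```

Swapping the tanh update for GRU gates changes only the step function; the reverse / scan / reverse-back wiring stays the same.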
Reversing in JAX
The “backward RNN” is just the same RNN type running on a REVERSED input, then the outputs are reversed again so they line up with the original positions:
x_b = x[None, ...] # add batch dim: (1, T, D)
fwd_h = rnn_fwd(x_b) # (1, T, H)
bwd_h = rnn_bwd(x_b[:, ::-1, :])[:, ::-1, :] # (1, T, H), aligned
out = jnp.concatenate([fwd_h, bwd_h], axis=-1) # (1, T, 2H)
Why two reverses? After running on reversed input, the result
is itself “in reverse order.” Reverse it back so position t of
the output matches position t of the input.
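The double reverse is easiest to see with a stand-in for the RNN. In this sketch (an illustration only, not a real RNN), jnp.cumsum plays the role of a left-to-right scan: its output at t depends only on x[0..t], just like a forward hidden state.

```python
import jax.numpy as jnp

def scan_left_to_right(x):
    # Stand-in for an RNN: prefix sums along the time axis of (B, T, D).
    return jnp.cumsum(x, axis=1)

x_b = jnp.array([1.0, 2.0, 3.0]).reshape(1, 3, 1)  # (1, T=3, D=1)

fwd = scan_left_to_right(x_b)                  # prefix sums
bwd_raw = scan_left_to_right(x_b[:, ::-1, :])  # computed on reversed input, still reversed
bwd = bwd_raw[:, ::-1, :]                      # re-aligned: suffix sums

print(fwd[0, :, 0])      # [1. 3. 6.]
print(bwd_raw[0, :, 0])  # [3. 5. 6.]
print(bwd[0, :, 0])      # [6. 5. 3.]
```

Without the second reverse, position 0 would hold 3, the suffix sum that belongs to the LAST position; after it, position t holds the sum of x[t..T-1].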
Flax’s nn.RNN
flax.linen.RNN(cell, ...) wraps a single-step cell (e.g.,
nn.GRUCell(features=H)) into a sequence-level layer that
sweeps through the time dim automatically. It expects a leading
BATCH axis, so add x[None, ...] before calling.
Each RNN(GRUCell(H)) instance has its own params. Create TWO of
them — one for forward, one for backward.
Worked walk-through
T=3, D=2, H=4, x=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]:
- x_b = x[None, ...] → (1, 3, 2): batch axis added.
- fwd = rnn_fwd(x_b) → (1, 3, 4).
- rev_in = x_b[:, ::-1, :] → still (1, 3, 2), but the order is flipped: [[0.5, 0.6], [0.3, 0.4], [0.1, 0.2]].
- rev_out = rnn_bwd(rev_in) → (1, 3, 4).
- bwd = rev_out[:, ::-1, :] → (1, 3, 4), re-aligned to the original positions.
- concat([fwd, bwd], axis=-1) → (1, 3, 8).
- Drop the batch axis → (3, 8), then flatten → (24,).
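The shapes in the walk-through can be checked without any RNN at all. This sketch (illustration only) uses the walk-through's x for the reversal and hypothetical stand-in arrays for the RNN outputs, filled with the position index (forward) and index + 10 (backward), so the flattened layout is readable: each block of 2H = 8 numbers belongs to one position, forward features first.

```python
import jax.numpy as jnp

x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # (T=3, D=2)
x_b = x[None, ...]                                   # (1, 3, 2)
rev_in = x_b[:, ::-1, :]                             # rows: [0.5, 0.6], [0.3, 0.4], [0.1, 0.2]

# Hypothetical stand-ins for the RNN outputs (H=4): every feature at position t
# holds t (forward) or t + 10 (backward).
T, H = 3, 4
fwd = jnp.repeat(jnp.arange(T, dtype=jnp.float32)[None, :, None], H, axis=-1)  # (1, 3, 4)
bwd = fwd + 10.0                                                               # (1, 3, 4)

out = jnp.concatenate([fwd, bwd], axis=-1)  # (1, 3, 8)
flat = out[0].reshape(-1)                   # drop batch -> (3, 8), flatten -> (24,)

# flat[0:4]  forward features of position 0
# flat[4:8]  backward features of position 0
# flat[8:16] position 1, and so on
print(flat.shape)  # (24,)
```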
Common pitfalls
- Sharing one RNN for both directions: the forward and backward RNNs are SEPARATE modules with separate weights.
- Forgetting to re-reverse the backward output: leaves the backward features mis-aligned with their positions — silently wrong.
- Concatenating along the WRONG axis: it must be the channel / feature axis (axis=-1), not time.
- Forgetting the batch axis: nn.RNN expects (B, T, D); x[None, ...] adds it.
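The wrong-axis pitfall is especially sneaky with this problem's flattened output. With stand-in arrays (zeros for forward, ones for backward; illustration only), concatenating on the time axis yields (1, 2T, H) instead of (1, T, 2H), yet both hold T * 2H numbers:

```python
import jax.numpy as jnp

T, H = 3, 4
fwd = jnp.zeros((1, T, H))  # stand-in forward outputs
bwd = jnp.ones((1, T, H))   # stand-in backward outputs

right = jnp.concatenate([fwd, bwd], axis=-1)  # (1, T, 2H): 2H features per step
wrong = jnp.concatenate([fwd, bwd], axis=1)   # (1, 2T, H): 2T "steps" of H features

print(right.shape, wrong.shape)  # (1, 3, 8) (1, 6, 4)
print(right.size, wrong.size)    # 24 24
```

A length check on the flattened result passes either way; only the element order differs, so check the intermediate (1, T, 2H) shape instead.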
Problem
Implement birnn_forward(seed, x, hidden):
- Define BiRNN(nn.Module) with a hidden field.
- Inside: build two nn.RNN(nn.GRUCell(features=hidden)) instances.
- Run the forward RNN on x[None, ...]. Run the backward RNN on the reversed input, then reverse its output again.
- Concatenate along axis=-1. Drop the batch dim. Flatten.
Inputs:
- seed: int.
- x: 2-D, shape (T, D).
- hidden: int H.
Output: 1-D, length T * 2H.