nn.Sequential Composition
Why this matters
Many real networks are just a stack of layers applied one after another:
x → Dense → ReLU → Dense → ReLU → Dense → y
Writing this out explicitly with repeated temporaries (x = Dense(x); x = relu(x))
works, but is verbose for long chains. Flax provides nn.Sequential to express
exactly this pattern: a linear chain of callables applied in order.
This problem teaches nn.Sequential and, more importantly, a key design principle
of Flax: pure callables (not just Modules) can appear anywhere in the forward
pass. jax.nn.relu is a plain Python function, not a Flax Module, but it works
perfectly as a Sequential item because it accepts and returns a single array.
nn.Sequential API
nn.Sequential([layer1, fn1, layer2, fn2, ...])
- Each item is called in order: y = item(x).
- Items can be Flax Modules (nn.Dense, nn.LayerNorm, etc.) or any Python callable that maps a single array to a single array (jax.nn.relu, jnp.tanh, a lambda, etc.).
- nn.Sequential itself is a Flax Module: it participates in the usual init/apply round-trip, and its parameter dict nests the parameters of all sub-Modules.
Worked 3-layer MLP example
import jax
import jax.numpy as jnp
import flax.linen as nn
class SeqMLP(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        return nn.Sequential([
            nn.Dense(features=self.hidden),
            jax.nn.relu,
            nn.Dense(features=self.hidden),
            jax.nn.relu,
            nn.Dense(features=self.out),
        ])(x)

model = SeqMLP(hidden=4, out=2)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 0.0, 0.0])
params = model.init(key, x)
y = model.apply(params, x)
print(y.shape) # (2,)
# Inspect parameter structure (exact key names can depend on the Flax version):
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'Sequential_0': {
# 'layers_0': {'bias': (4,), 'kernel': (3, 4)},
# 'layers_2': {'bias': (4,), 'kernel': (4, 4)},
# 'layers_4': {'bias': (2,), 'kernel': (4, 2)},
# }}}
Note: the pure functions (jax.nn.relu at positions 1 and 3) have no parameters
and don't appear in the params dict. Only the Dense layers at positions 0, 2, 4
appear as layers_0, layers_2, layers_4.
Comparing Sequential vs explicit code
These two define the same network:
# Explicit version
@nn.compact
def __call__(self, x):
    x = nn.Dense(self.hidden)(x)
    x = jax.nn.relu(x)
    x = nn.Dense(self.hidden)(x)
    x = jax.nn.relu(x)
    x = nn.Dense(self.out)(x)
    return x

# Sequential version
@nn.compact
def __call__(self, x):
    return nn.Sequential([
        nn.Dense(self.hidden),
        jax.nn.relu,
        nn.Dense(self.hidden),
        jax.nn.relu,
        nn.Dense(self.out),
    ])(x)

Given the same seed, the outputs may still differ numerically: if Sequential
wraps the layers in a sub-Module (Sequential_0), the param naming path changes,
which changes the RNG fold-in during init. Either way, both implement the same
MLP topology.
When to use Sequential vs explicit
Use Sequential when:
- The entire forward pass is one linear chain.
- You want clean, compact code for simple MLPs.
- You're building a pipeline of preprocessing + model layers.
Use explicit code when:
- You have skip connections, residuals, or any topology that is not a single linear chain.
- You need access to intermediate activations.
- You have layers with multiple inputs/outputs (attention, etc.).
nn.Sequential is meant for single-input, single-output chains. A layer that
takes two separate inputs or returns two outputs generally does not fit this
pattern.
Common pitfalls
- Multi-input layers in Sequential: putting an attention layer (which needs query, key, value separately) into nn.Sequential will fail when the chain is called, because Sequential passes each layer only the previous item's output.
- Mutable state in Sequential items: if you put a BatchNorm (which has mutable running statistics) in Sequential, you need to handle the use_running_average flag, for example by setting it when the layer is constructed. For inference, pass model.apply(params, x, train=False), assuming the wrapping module's __call__ accepts a train argument.
- Forgetting to call the Sequential: in nn.Sequential([...])(x), you construct the Sequential object first, then call it on x.
Problem
Build a 3-layer MLP using nn.Sequential:
Dense(hidden) → ReLU → Dense(hidden) → ReLU → Dense(out).
Wrap it in a @nn.compact module. Init and apply.
Inputs:
- seed: int (passed as float; cast to int).
- x: 1-D JAX array.
- hidden: int (cast to int), width of both hidden layers.
- out: int (cast to int), number of output features.
Output: 1-D array of length out.