
nn.Sequential Composition

Why this matters

Many real networks are just a stack of layers applied one after another:

x → Dense → ReLU → Dense → ReLU → Dense → y

Writing this out explicitly with repeated temporaries (x = Dense(x); x = relu(x)) works, but it is verbose for long chains. Flax provides nn.Sequential to express exactly this pattern: a linear chain of callables applied in order.

This problem teaches nn.Sequential and, more importantly, a key design principle of Flax: pure callables (not just Modules) can appear anywhere in the forward pass. jax.nn.relu is a plain JAX function with no parameters or state, not a Flax Module, yet it works as a Sequential item because it accepts and returns a single array.

nn.Sequential API

nn.Sequential([layer1, fn1, layer2, fn2, ...])
  • Items are called in order, and each item's output becomes the next item's input: y = item(x).
  • Items can be Flax Modules (nn.Dense, nn.LayerNorm, etc.) or any Python callable that maps a single array to a single array (jax.nn.relu, jnp.tanh, a lambda, etc.).
  • nn.Sequential itself is a Flax Module: it participates in the usual init / apply round-trip, and its parameter dict nests the parameters of all sub-Modules.

Worked 3-layer MLP example

import jax
import jax.numpy as jnp
import flax.linen as nn

class SeqMLP(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        return nn.Sequential([
            nn.Dense(features=self.hidden),
            jax.nn.relu,
            nn.Dense(features=self.hidden),
            jax.nn.relu,
            nn.Dense(features=self.out),
        ])(x)

model = SeqMLP(hidden=4, out=2)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 0.0, 0.0])

params = model.init(key, x)
y = model.apply(params, x)
print(y.shape)  # (2,)

# Inspect parameter structure:
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'Sequential_0': {
#     'layers_0': {'bias': (4,), 'kernel': (3, 4)},
#     'layers_2': {'bias': (4,), 'kernel': (4, 4)},
#     'layers_4': {'bias': (2,), 'kernel': (4, 2)},
# }}}

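Note: the pure functions (jax.nn.relu at positions 1 and 3) have no parameters and don’t appear in the params dict. Sub-Modules are named after their index in the layers list, so the Dense layers at positions 0, 2 and 4 appear as layers_0, layers_2 and layers_4; the numbering skips the ReLU slots.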

Comparing Sequential vs explicit code

These two define the same architecture:

# Explicit version
@nn.compact
def __call__(self, x):
    x = nn.Dense(self.hidden)(x)
    x = jax.nn.relu(x)
    x = nn.Dense(self.hidden)(x)
    x = jax.nn.relu(x)
    x = nn.Dense(self.out)(x)
    return x

# Sequential version
@nn.compact
def __call__(self, x):
    return nn.Sequential([
        nn.Dense(self.hidden),
        jax.nn.relu,
        nn.Dense(self.hidden),
        jax.nn.relu,
        nn.Dense(self.out),
    ])(x)

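Initialized with the same PRNG key, the two versions still produce different outputs. Flax derives each sub-Module's init RNG from its parent's RNG and the sub-Module's name, and Sequential adds an extra level to the path (Sequential_0/layers_0 instead of Dense_0), so each layer receives a different key and therefore different initial weights. Both implement the same MLP topology; given the same weights, they compute the same function.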

When to use Sequential vs explicit

Use Sequential when:

  • The entire forward pass is one linear chain.
  • You want clean, compact code for simple MLPs.
  • You’re building a pipeline of preprocessing + model layers.

Use explicit code when:

  • You have skip connections, residuals, or any other topology that is not a single chain.
  • You need access to intermediate activations.
  • You have layers with multiple inputs/outputs (attention, etc.).

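nn.Sequential is meant for chains in which each step feeds only the next one. Flax's implementation does unpack a tuple returned by one item into positional arguments for the next item (and a dict into keyword arguments), but it cannot express branches, merges, or inputs that come from anywhere other than the previous item.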

Common pitfalls

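  • Layers that need inputs from outside the chain: each item receives only the previous item's output (the first item receives the arguments passed to the Sequential call). A cross-attention layer that needs a query plus a separate key/value tensor from elsewhere in the network therefore cannot sit in the middle of a Sequential.
  • Mutable state in Sequential items: a BatchNorm keeps running statistics in the batch_stats collection, and its use_running_average flag must be set. Sequential forwards extra keyword arguments only to its first item, so set the flag when constructing the layer, e.g. nn.BatchNorm(use_running_average=not train) inside a compact __call__(self, x, train). Call model.apply(variables, x, train=True, mutable=['batch_stats']) during training and model.apply(variables, x, train=False) for inference.
  • Forgetting to call the Sequential: nn.Sequential([...]) only constructs the module; you must then call it on the input, as in nn.Sequential([...])(x).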

Problem

Build a 3-layer MLP using nn.Sequential: Dense(hidden) → ReLU → Dense(hidden) → ReLU → Dense(out). Wrap it in a @nn.compact module, then init and apply it.

Inputs:

  • seed: int (passed as a float; cast to int).
  • x: 1-D JAX array.
  • hidden: int (cast to int), the width of both hidden layers.
  • out: int (cast to int), the number of output features.

Output: 1-D array of length out.

Hints

flax sequential compose
