Module with @nn.compact
Why this matters
Flax, through its Linen API, is one of the most widely used libraries for building neural networks in JAX.
Where PyTorch has nn.Module (with stateful, mutable instances), Flax has
flax.linen.Module, but params live OUTSIDE the module instance. This is
fundamental: Flax modules are pure blueprints. They define structure; params
are produced by calling model.init(rng, *example_inputs) and threaded through
every forward call via model.apply(params, *inputs).
The @nn.compact decorator is the most common way to define a Module: instead of
a separate setup() method (which declares submodules) and __call__ (which uses
them), @nn.compact lets you declare and call submodules inline in __call__.
Submodules are auto-named from their class name plus a per-class counter in
declaration order (Dense_0, Dense_1, ...). This style is concise and a natural
starting point for most Flax code.
This problem teaches the core init/apply round-trip: build a Module → init params → apply params on input → return output.
Motivation: why not just store params in __init__?
PyTorch does this: self.linear = nn.Linear(in_features, out_features) stores both
the layer structure and its weights in the same object. This works fine for
single-device, eager-mode code, but breaks under JAX's functional model. Functions
transformed by jax.jit, jax.grad and jax.vmap must be pure: same inputs → same
outputs, no hidden state. A PyTorch-style mutable module is not a pure function.
Flax's solution: separate the blueprint (the Module class) from the weights
(the params dict). The Module never holds weights; init runs (traces) the forward
pass on an example input to create the weights as a nested dict (a pytree), and
apply runs the forward pass with those weights passed in as an explicit argument.
model = DenseCompact(features=8)   # blueprint: no weights
params = model.init(rng, x)        # weights: a plain nested dict
y = model.apply(params, x)         # forward pass: no mutation
This makes modules trivially serializable, composable with jit/vmap/grad,
and easy to checkpoint.
Anatomy of @nn.compact
import flax.linen as nn

class DenseCompact(nn.Module):
    features: int  # Python dataclass field, set at construction

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features)(x)  # declare + call inline
The features attribute is a hyperparameter, not a weight. It's set when you
build the module: DenseCompact(features=8). The actual weights (the Dense layer's
kernel matrix and bias vector) are created only when init is called; the kernel's
input dimension is inferred from the shape of the example input.
Worked mini-example
import jax
import jax.numpy as jnp
import flax.linen as nn
class TinyModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return jax.nn.relu(x)
model = TinyModel(features=8)
key = jax.random.PRNGKey(0)
x = jnp.ones((16,)) # 1-D input: 16 features
params = model.init(key, x) # build params
y = model.apply(params, x) # forward
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'Dense_0': {'bias': (8,), 'kernel': (16, 8)}}}
Notice the naming: the first nn.Dense in a compact module is auto-named
Dense_0, the second Dense_1, and so on.
Common pitfalls
- init needs an example input of the right shape: parameter shapes are inferred from the example input passed to init (its values don't matter, so jnp.ones or jnp.zeros is fine). If the example's shape doesn't match the inputs you later pass to apply, init typically still succeeds but builds wrongly shaped params, and the mismatch surfaces as an error in apply. Always pass a representative input.
- apply does NOT mutate the model: the params dict is never updated in place. apply only returns the forward-pass result. To update params (e.g. in training), compute gradients with jax.grad and apply them with an optimizer such as Optax.
- Params live in a dict, not on the module: don't try model.kernel. Access weights via params["params"]["Dense_0"]["kernel"].
- Multiple Dense instances in compact: they get auto-named Dense_0, Dense_1, … in declaration order. To assign custom names, pass name= to the layer constructor or use setup() instead (see the next problem).
- PRNGKey vs int seed: jax.random.PRNGKey(seed) takes an integer; the problem passes seed as a float, so cast it to int first.
Problem
Build a single-layer Flax Module using @nn.compact that wraps one nn.Dense
layer. Init it with jax.random.PRNGKey(seed), apply it to x, and return the
result.
Inputs:
- seed: int (passed as a float; cast to int).
- x: 1-D JAX array.
- features: int (number of output features; cast to int).
Output: 1-D array of length features, the Dense layer output.