medium primitives

Module with @nn.compact

Why this matters

Flax (Linen API) is a widely used library for building neural networks in JAX. Where PyTorch has nn.Module (with stateful, mutable instances), Flax has flax.linen.Module, but params live outside the module instance. This is fundamental: Flax modules are pure blueprints. They define structure; params are produced by calling model.init(rng, *example_inputs) and passed explicitly to every forward call via model.apply(params, *inputs).

The @nn.compact decorator is a very common way to define a Module: instead of a separate setup() method (which declares submodules) and __call__ (which uses them), @nn.compact lets you declare and call submodules inline in __call__. Unnamed submodules are auto-named by class name and declaration order (Dense_0, Dense_1, …), not by the Python variable they are assigned to. It is concise and a reasonable starting point for most Flax code.

This problem teaches the core init/apply round-trip: build a Module → init params → apply params on input → return output.

Motivation: why not just store params in __init__?

PyTorch does this: self.linear = nn.Linear(in, out) stores both the layer structure and its weights in the same object. This works fine for single-device, eager-mode code, but breaks under JAX’s functional model. JAX functions must be pure: same inputs → same outputs, no hidden state. A PyTorch-style mutable module is not a pure function.

Flax’s solution: separate the blueprint (the Module class) from the weights (the params dict). The Module never holds weights; init traces the forward pass to build the weights as a plain pytree dict, and apply runs forward with those weights as an explicit argument.

model = DenseCompact(features=8)      # blueprint: no weights
params = model.init(rng, x)            # weights: plain dict
y = model.apply(params, x)             # forward: no mutation

This makes modules trivially serializable, composable with jit/vmap/grad, and easy to checkpoint.

Anatomy of @nn.compact

class DenseCompact(nn.Module):
    features: int              # Python dataclass field, set at construction

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features)(x)   # declare + call inline

The features attribute is a hyperparameter, not a weight. It’s set when you build the module: DenseCompact(features=8). The actual weights (the kernel and bias of the Dense layer) are created only when init is called with a concrete input, because the kernel’s shape depends on the input’s last dimension.

Worked mini-example

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return jax.nn.relu(x)

model = TinyModel(features=8)
key = jax.random.PRNGKey(0)
x = jnp.ones((16,))                           # 1-D input: 16 features
params = model.init(key, x)                    # build params
y = model.apply(params, x)                     # forward
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'Dense_0': {'kernel': (16, 8), 'bias': (8,)}}}

Notice the naming: the first nn.Dense in a compact module is auto-named Dense_0, the second Dense_1, and so on.

Common pitfalls

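  • init requires a real example input: parameter shapes are inferred from the example input passed to init. If that input has the wrong shape, init does not complain; it silently builds params with the wrong shapes, and the error only appears later when apply sees a real input whose shape does not match. Always pass a representative input.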
  • apply does NOT mutate the model: the params dict is never updated in-place. apply purely returns the forward-pass result. To update params (e.g. in training), compute gradients and use an Optax optimizer.
  • Params live in a dict, not on the module: don’t try model.kernel. Access weights via params["params"]["Dense_0"]["kernel"].
  • Multiple Dense instances in compact: they get auto-named Dense_0, Dense_1, … in declaration order. To assign custom names, pass name= to the layer constructor or use setup() instead (see the next problem).
  • PRNGKey vs int seed: jax.random.PRNGKey(seed) takes an integer; the problem passes seed as a float, so cast it to int first.

Problem

Build a single-layer Flax Module using @nn.compact that wraps one nn.Dense layer. Init it with jax.random.PRNGKey(seed), apply it to x, and return the result.

Inputs:

  • seed: int (passed as float; cast to int).
  • x: 1-D JAX array.
  • features: int (number of output features β€” cast to int).

Output: 1-D array of length features β€” the Dense layer output.

Hints

flax module compact
