medium primitives

init() and apply() Round-Trip

Why this matters

The init / apply round-trip is the central workflow in Flax. Every Flax program follows this pattern:

  1. Build the blueprint: model = MyModule(hyperparams=...). No weights exist yet.
  2. Init to get weights: params = model.init(rng, *example_inputs). This runs the forward pass on a dummy input to determine shapes and draws random weights.
  3. Apply for the forward pass: y = model.apply(params, *inputs). This runs the forward pass with the given params.
  4. Repeat step 3 on every inference or training step. The weights never live "inside" the model.

This problem asks you to build a two-layer MLP with this pattern: Dense(hidden) → ReLU → Dense(out).

What init() actually does

model.init(rng, x) does three things:

  1. Shape inference: Flax runs __call__ on the example input. Each nn.Dense(features=N) layer reads the last dimension of its input to determine the kernel shape: (n_in, N).
  2. RNG allocation: each parameter gets its own unique key, derived from the top-level rng via the parameter's name path. Two calls to init with the same rng and the same input shapes produce the same weights.
  3. Returns a nested dict: the output is {"params": {...}}, mapping layer names → {"kernel": array, "bias": array}. Depending on the Flax version, this is an immutable FrozenDict or a plain dict.

For a two-layer MLP with x of shape (3,), hidden=4, out=2:

params["params"]["Dense_0"]["kernel"]   # shape (3, 4)
params["params"]["Dense_0"]["bias"]     # shape (4,)
params["params"]["Dense_1"]["kernel"]   # shape (4, 2)
params["params"]["Dense_1"]["bias"]     # shape (2,)

You can inspect all shapes at once:

jax.tree_util.tree_map(jnp.shape, params)
# {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)},
#             'Dense_1': {'bias': (2,), 'kernel': (4, 2)}}}
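
As a quick sanity check on these shapes, the total parameter count can be computed from the tree. This sketch builds a stand-in tree of zeros with the same structure rather than calling init, so the values are placeholders and only the shapes matter.

```python
import jax
import jax.numpy as jnp

# Stand-in tree with the shapes listed above (zeros, not real weights).
params = {
    "params": {
        "Dense_0": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros((4,))},
        "Dense_1": {"kernel": jnp.zeros((4, 2)), "bias": jnp.zeros((2,))},
    }
}

# Flatten the tree into its leaf arrays and sum their sizes.
leaves = jax.tree_util.tree_leaves(params)
n_params = sum(leaf.size for leaf in leaves)
print(n_params)  # 3*4 + 4 + 4*2 + 2 = 26
```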

What apply() does

model.apply(params, x) is a pure function call: it runs the forward pass using the provided params. It does NOT:

  • Mutate params.
  • Allocate new RNG keys (unless you pass rngs={"dropout": key}).
  • Track gradients (use jax.grad around the whole apply call for that).

This purity is crucial: it means apply is trivially jit-able, vmap-able, and differentiable.

Two-layer MLP forward pass, worked

import jax
import jax.numpy as jnp
import flax.linen as nn

class TwoLayer(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden)(x)   # Dense_0
        x = jax.nn.relu(x)
        x = nn.Dense(features=self.out)(x)      # Dense_1
        return x

model = TwoLayer(hidden=4, out=2)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 0.5, 0.0])              # input: 3 features

params = model.init(key, x)
y = model.apply(params, x)
print(y.shape)  # (2,)

# Inspect shapes:
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)},
#             'Dense_1': {'bias': (2,), 'kernel': (4, 2)}}}

Using apply in a training loop

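def loss_fn(params, x, y_true):
    y_pred = model.apply(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

# Placeholder batch for illustration: 8 examples, 3 features in, 2 targets out.
x_batch = jnp.ones((8, 3))
y_batch = jnp.zeros((8, 2))

grad_fn = jax.value_and_grad(loss_fn)   # differentiates w.r.t. params (argument 0)
loss, grads = grad_fn(params, x_batch, y_batch)
# update params with grads via optax...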

apply is just a function, so it composes with all of JAX's transformations.

Common pitfalls

  • Mismatched shapes between init and apply: nn.Dense acts on the last axis, so its kernel shape is fixed by the last dimension seen during init. If x has shape (3,) during init, you can later pass a batch of shape (4, 3) to apply and the layer broadcasts over the leading axis. An input of shape (3, 4) has the wrong last dimension and raises a shape error.
  • Returning model.init(...) instead of model.apply(...): init returns the params dict, not the forward result. You need apply to get the output.
  • Forgetting the train/eval flag: if your module has Dropout (deterministic) or BatchNorm (use_running_average), pass a keyword such as train=False via model.apply(params, x, train=False). The keyword name is whatever your __call__ signature defines.

Problem

Build a two-layer MLP: Dense(hidden) → ReLU → Dense(out). Use @nn.compact. Init with PRNGKey(seed), apply to x, return the output.

Inputs:

  • seed: int (passed as float; cast to int).
  • x: 1-D JAX array.
  • hidden: int (cast to int), the width of the hidden layer.
  • out: int (cast to int), the number of output features.

Output: 1-D array of length out.

Hints

flax init apply
