init() and apply() Round-Trip
Why this matters
The init / apply round-trip is the central workflow in Flax. Every Flax program
follows this pattern:
1. Build the blueprint: model = MyModule(hyperparams=...). No weights exist yet.
2. Init to get weights: params = model.init(rng, *example_inputs). This traces the forward pass with a dummy input to determine shapes and draws random weights.
3. Apply for the forward pass: y = model.apply(params, *inputs). This runs the forward pass with the given params.
4. Repeat step 3 on every inference or training step. Weights never live "inside" the model.
This problem asks you to build a two-layer MLP with this pattern:
Dense(hidden) → ReLU → Dense(out).
What init() actually does
model.init(rng, x) does three things:
- Shape inference: JAX traces through __call__ with abstract shapes from the example input. Each nn.Dense(features=N) layer reads the input shape to determine the kernel shape: (n_in, N).
- RNG allocation: each parameter gets its own unique key derived from the top-level rng via the param name path. Two calls to init with the same rng produce the same weights.
- Returns a nested dict: the output is {"params": {...}}, mapping layer names to {"kernel": array, "bias": array}. Depending on the Flax version, this is either an immutable FrozenDict or a plain Python dict.
For a two-layer MLP with x of shape (3,), hidden=4, out=2:
params["params"]["Dense_0"]["kernel"] # shape (3, 4)
params["params"]["Dense_0"]["bias"] # shape (4,)
params["params"]["Dense_1"]["kernel"] # shape (4, 2)
params["params"]["Dense_1"]["bias"] # shape (2,)
You can inspect all shapes at once:
jax.tree_util.tree_map(jnp.shape, params)
# {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)},
# 'Dense_1': {'bias': (2,), 'kernel': (4, 2)}}}
What apply() does
model.apply(params, x) is a pure function call: it runs the forward pass using
the provided params. It does NOT:
- Mutate params.
- Allocate new RNG keys (unless you pass rngs={"dropout": key}).
- Track gradients (wrap the whole apply call in jax.grad for that).
This purity is crucial: it means apply is trivially jit-able, vmap-able,
and differentiable.
Two-layer MLP forward pass, worked
import jax
import jax.numpy as jnp
import flax.linen as nn
class TwoLayer(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden)(x)  # Dense_0
        x = jax.nn.relu(x)
        x = nn.Dense(features=self.out)(x)     # Dense_1
        return x

model = TwoLayer(hidden=4, out=2)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 0.5, 0.0]) # input: 3 features
params = model.init(key, x)
y = model.apply(params, x)
print(y.shape) # (2,)
# Inspect shapes:
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)},
# 'Dense_1': {'bias': (2,), 'kernel': (4, 2)}}}
Using apply in a training loop
def loss_fn(params, x, y_true):
    y_pred = model.apply(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

grad_fn = jax.value_and_grad(loss_fn)
# x_batch: inputs of shape (batch, 3); y_batch: targets of shape (batch, 2)
loss, grads = grad_fn(params, x_batch, y_batch)
# update params with grads via optax...
apply is just a function: it composes with all of JAX's transformations.
Common pitfalls
- Passing the wrong shape to init: Dense acts on the last axis of its input, so if x has shape (3,) during init, the first kernel has shape (3, hidden). Passing a batch of shape (4, 3) to apply later works, because only the last axis has to match and leading axes are treated as batch dimensions. Passing (3, 4) fails, because the last axis is then 4 rather than 3.
- Returning model.init(...) instead of model.apply(...): init returns the params dict, not the forward result. You need apply to get the output.
- Forgetting deterministic mode: if your module has Dropout (deterministic) or BatchNorm (use_running_average), pass model.apply(params, x, train=False) or a similar keyword at evaluation time. The keyword name is whatever your own __call__ signature defines.
Problem
Build a two-layer MLP: Dense(hidden) → ReLU → Dense(out). Use @nn.compact.
Init with PRNGKey(seed), apply to x, return the output.
Inputs:
- seed: int (passed as float; cast to int).
- x: 1-D JAX array.
- hidden: int (cast to int), the width of the hidden layer.
- out: int (cast to int), the number of output features.
Output: 1-D array of length out.