NNX Implement Dense
Why this matters
Dense (a.k.a. Linear, fully-connected) is the most widely used layer in deep
learning. Reimplementing it in nnx is the cleanest way to internalize
the layer-as-class pattern that the rest of the track depends on. After
this problem, every other layer is a small variation on the same
template: different math in __call__, sometimes more parameters,
occasionally non-trainable state, but always the same skeleton.
Compare with Linen: in Linen you’d subclass nn.Module, declare
features: int as a dataclass field, and use self.param(...) inside
a @nn.compact method to allocate weights lazily on the first call.
Then model.init(key, x) runs the forward once to materialize them,
returning a params dict you’d carry around. In nnx the parameters
are just attributes of the module; construction allocates them; calling
runs the forward.
No init, no apply. Build, then call.
API: parameters as attributes
```python
import jax
import jax.numpy as jnp
from flax import nnx

class MyDense(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()  # draw a fresh key for parameter init
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))  # scale by 1/sqrt(fan_in)
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.kernel + self.bias
```
Two parameters, both nnx.Param-wrapped:
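- kernel: shape (in_features, out_features), lecun-normal-style init: normal * (1 / sqrt(in_features)), which gives each weight variance 1 / in_features. Lecun-normal is the default kernel initializer for Linen's nn.Dense. Note that Flax's lecun_normal samples a truncated normal, rescaled to the same variance, so its values will not match this plain-normal version exactly.
- bias: shape (out_features,), zero init. Zero bias is also the Linen default.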
The forward is the textbook formula: x @ W + b. With x shape
(in_features,) and kernel shape (in_features, out_features),
the output is shape (out_features,).
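To make the shapes concrete, here is the same formula in plain JAX with arbitrary random weights. No nnx is involved, and W and b are illustrative names. The einsum spells out y[j] = sum_i x[i] * W[i, j] + b[j]. The last lines show that the same expression also accepts a leading batch axis, with b broadcasting over it.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
in_features, out_features = 3, 4
W = jax.random.normal(key_w, (in_features, out_features))
b = jnp.full((out_features,), 0.5)  # nonzero so the bias term is visible

x = jnp.array([1.0, 2.0, 3.0])
y = x @ W + b
# Written out term by term: y[j] = sum_i x[i] * W[i, j] + b[j]
y_explicit = jnp.einsum("i,ij->j", x, W) + b
print(y.shape, bool(jnp.allclose(y, y_explicit)))  # (4,) True

# A batch of inputs works unchanged: (batch, in) @ (in, out) -> (batch, out),
# and b broadcasts across the batch axis.
xb = jax.random.normal(key_x, (5, in_features))
print((xb @ W + b).shape)  # (5, 4)
```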
Worked example
```python
# Continues from the MyDense definition above.
rngs = nnx.Rngs(0)
model = MyDense(in_features=3, out_features=4, rngs=rngs)
print(model.kernel.value.shape)  # (3, 4)
print(model.bias.value.shape)    # (4,)
x = jnp.array([1.0, 2.0, 3.0])
y = model(x)  # shape (4,)
```
Now compare to Linen:
```python
# Linen, for contrast.
import jax
import flax.linen as nn

class MyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel", nn.initializers.lecun_normal(),
            (x.shape[-1], self.features))
        bias = self.param("bias", nn.initializers.zeros, (self.features,))
        return x @ kernel + bias

model = MyDense(features=4)
params = model.init(jax.random.PRNGKey(0), x)  # {"params": {"kernel": ..., "bias": ...}}
y = model.apply(params, x)
```
Same math, but in Linen the params dict is external and you have to
thread it through apply. In nnx the model owns the params and you
just call it.
Common pitfalls
- Forgetting the nnx.Param wrapper. self.kernel = jax.random.normal(...) leaves the kernel as a plain array attribute instead of a parameter, so nnx.state(model, nnx.Param), Param-filtered nnx.split, and optimizers will not treat it as trainable. Always wrap trainable arrays in nnx.Param.
- Wrong init scale. jax.random.normal alone has stddev 1.0. The lecun-normal trick is to scale by 1 / sqrt(fan_in). For zero-mean, unit-variance inputs, this keeps each output's variance near 1 instead of growing by a factor of fan_in, which would make activations explode through deep stacks.
- Bias shaped (in_features,). Bias is added to the output, so its shape must match out_features.
- features passed as float. The harness sends numeric inputs as floats, so cast shape arguments to int before using them.
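The init-scale pitfall is easy to see numerically. This sketch uses plain JAX, and the sizes are arbitrary. It feeds zero-mean, unit-variance inputs through a 512-input layer. With unscaled normal weights, the output variance comes out near fan_in. Scaling by 1 / sqrt(fan_in) brings it back to about 1. For reference, it also checks jax.nn.initializers.lecun_normal, which samples a truncated normal but rescales it so the weight variance is still 1 / fan_in.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out, batch = 512, 256, 1000
kx, kw, kl = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (batch, fan_in))  # zero-mean, unit-variance inputs

w_raw = jax.random.normal(kw, (fan_in, fan_out))  # stddev 1.0
w_scaled = w_raw * (1.0 / jnp.sqrt(fan_in))       # stddev 1/sqrt(fan_in)

print(float(jnp.var(x @ w_raw)))     # roughly fan_in (about 512)
print(float(jnp.var(x @ w_scaled)))  # roughly 1.0

# JAX's built-in LeCun-normal draws from a truncated normal, rescaled so the
# weight variance is still 1/fan_in; multiplying by fan_in should give about 1.
w_lecun = jax.nn.initializers.lecun_normal()(kl, (fan_in, fan_out))
print(float(jnp.var(w_lecun) * fan_in))  # roughly 1.0
```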
Problem
Write dense_forward(seed, x, features):
- Define MyDense(nnx.Module) with two trainable parameters:
  - self.kernel = nnx.Param(...) with shape (in_features, out_features), init jax.random.normal(key, ...) * (1 / sqrt(in_features)).
  - self.bias = nnx.Param(jnp.zeros((out_features,))).
- __call__(self, x) returns x @ self.kernel + self.bias.
- Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], out_features=int(features)), and return the forward output.
Inputs:
- seed: int (passed as float; cast to int).
- x: 1-D JAX array.
- features: int (passed as float; cast to int).
Output: 1-D array of length features.