
NNX Implement Dense

Why this matters

Dense (a.k.a. Linear or fully connected) is one of the most widely used layers in deep learning. Reimplementing it in nnx is the cleanest way to internalize the layer-as-class pattern that the rest of the track depends on. After this problem, every other layer is a small variation on the same template: different math in __call__, sometimes more parameters, occasionally non-trainable state, but always the same skeleton.

Compare with Linen: there you’d subclass nn.Module, declare features: int as a dataclass field, and call self.param(...) inside an @nn.compact method to declare weights lazily. model.init(key, x) then runs the forward once to materialize them and returns a variables dict (weights under "params") that you carry around and pass to apply. In nnx the parameters are plain attributes of the module: construction allocates them, and calling runs the forward.

No init, no apply. Build, then call.

API: parameters as attributes

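import jax
import jax.numpy as jnp
from flax import nnx

class MyDense(nnx.Module):
    def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
        key = rngs.params()  # draw a key for parameter initialization
        # LeCun-style scaling: standard-normal samples times 1/sqrt(fan_in)
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))  # zero-initialized bias

    def __call__(self, x):
        return x @ self.kernel + self.bias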

Two parameters, both nnx.Param-wrapped:

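  • kernel: shape (in_features, out_features), LeCun-normal-style init: normal * (1 / sqrt(in_features)), so each weight has variance 1 / in_features. LeCun-normal is the default kernel init for Linen’s nn.Dense and what most reference implementations use. Flax’s own lecun_normal draws from a truncated normal rescaled to that same variance, so this plain-normal version matches it in scale, not sample-for-sample.
  • bias: shape (out_features,), zero init. A zero bias is also the Linen default.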
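To see what the scaling does and how it compares with the built-in initializer, the sketch below (assuming a JAX version that ships jax.nn.initializers.lecun_normal) draws a 512 by 256 kernel both ways. Both standard deviations should land close to 1/sqrt(512). The built-in one is a truncated normal, so its largest entry is bounded at about 2.27 times that target, while the plain normal has a longer tail.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 512, 256
key = jax.random.PRNGKey(0)
target = 1.0 / jnp.sqrt(fan_in)  # desired stddev: sqrt(1 / fan_in)

# Manual version, as in MyDense: plain normal scaled by 1/sqrt(fan_in).
manual = jax.random.normal(key, (fan_in, fan_out)) * target

# JAX's built-in LeCun-normal: truncated normal rescaled to variance 1/fan_in.
lecun = jax.nn.initializers.lecun_normal()(key, (fan_in, fan_out))

print("manual std:", float(manual.std()))
print("lecun std: ", float(lecun.std()))
print("target:    ", float(target))
# Largest magnitude relative to the target stddev (tail behavior).
print("max |w| manual / target:", float(jnp.abs(manual).max() / target))
print("max |w| lecun  / target:", float(jnp.abs(lecun).max() / target))
```

For a Dense layer either choice works; the truncated version just avoids the occasional very large initial weight.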

The forward is the textbook formula: x @ W + b. With x shape (in_features,) and kernel shape (in_features, out_features), the output is shape (out_features,).

Worked example

rngs = nnx.Rngs(0)
model = MyDense(in_features=3, out_features=4, rngs=rngs)
print(model.kernel.value.shape)            # (3, 4)
print(model.bias.value.shape)              # (4,)
x = jnp.array([1.0, 2.0, 3.0])
y = model(x)                                # shape (4,)

Now compare to Linen:

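# Linen, for contrast.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Parameters are declared here and created on the first (init) call.
        kernel = self.param(
            "kernel", nn.initializers.lecun_normal(),
            (x.shape[-1], self.features))
        bias = self.param("bias", nn.initializers.zeros, (self.features,))
        return x @ kernel + bias

x = jnp.array([1.0, 2.0, 3.0])
model = MyDense(features=4)
params = model.init(jax.random.PRNGKey(0), x)  # {"params": {"kernel": ..., "bias": ...}}
y = model.apply(params, x)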

Same math, but in Linen the params dict is external and you have to thread it through apply. In nnx the model owns the params and you just call it.

Common pitfalls

  • Forgetting nnx.Param wrapper. self.kernel = jax.random.normal(...) makes the kernel a static attribute; nnx.split, nnx.state, and optimizers will not see it. Always wrap trainable arrays in nnx.Param.
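  • Wrong init scale. jax.random.normal alone has stddev 1.0, so each layer multiplies the activation scale by roughly sqrt(fan_in). The LeCun-normal fix is to scale by 1 / sqrt(fan_in), which keeps the output variance close to the input variance so activations don’t explode through deep stacks.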
  • Bias as (in_features,). Bias is added to the output, so its shape must match out_features.
  • features passed as float. The harness sends numeric inputs as floats — cast shape arguments to int before using them.
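The init-scale pitfall is easy to see numerically. The sketch below uses a simplified stand-in for a deep stack: 10 purely linear 256 by 256 layers with no biases or activations, a hypothetical setup chosen only for illustration. With unit-stddev weights each layer multiplies the activation scale by about sqrt(256) = 16; with the 1/sqrt(fan_in) scaling it stays near 1.

```python
import jax
import jax.numpy as jnp

width, depth = 256, 10
keys = jax.random.split(jax.random.PRNGKey(0), depth + 1)
x0 = jax.random.normal(keys[0], (64, width))  # 64 unit-variance inputs

def forward(x, scale):
    # Stack of bias-free linear layers with weights ~ N(0, scale**2).
    for k in keys[1:]:
        w = jax.random.normal(k, (width, width)) * scale
        x = x @ w
    return x

unscaled = forward(x0, 1.0)
scaled = forward(x0, 1.0 / jnp.sqrt(width))
print("unscaled output std:", float(unscaled.std()))
print("scaled output std:  ", float(scaled.std()))
```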

Problem

Write dense_forward(seed, x, features):

  1. Define MyDense(nnx.Module) with two trainable parameters:
    • self.kernel = nnx.Param(...) shape (in_features, out_features), init jax.random.normal(key, ...) * (1 / sqrt(in_features)).
    • self.bias = nnx.Param(jnp.zeros((out_features,))).
  2. __call__(self, x) returns x @ self.kernel + self.bias.
  3. Build nnx.Rngs(int(seed)), instantiate the module with in_features=x.shape[-1] and out_features=int(features), and return the forward output.

Inputs:

  • seed: int (passed as float — cast to int).
  • x: 1-D JAX array.
  • features: int (passed as float — cast to int).

Output: 1-D array of length features.

Hints

flax nnx dense reimplementation
