hard primitives

masked: Apply WD Only To Certain Params

Why this matters

Weight decay should typically be applied to weight matrices but NOT to bias parameters, layer-norm scales, or embedding tables, because applying WD to those can hurt convergence. optax.masked is the standard primitive for per-parameter-group transforms; this problem demonstrates the underlying logic using element-wise masking on a 1-D array.

The recipe

optax.masked(transform, mask) applies transform only to the pytree leaves whose mask entry is True and passes the remaining updates through unchanged. For a 1-D array, the equivalent element-wise computation is:

# Add WD only to masked positions, then apply SGD
bool_mask = mask > 0.5
wd_grads = jnp.where(bool_mask, wd * params, 0.0)
combined_grads = grads + wd_grads
optimizer = optax.sgd(lr)
...

Note: optax.masked operates on pytree nodes, not individual elements of a 1-D array. For element-level selection, jnp.where replicates the semantics cleanly.

Common pitfalls

  • Passing a multi-element JAX array directly as the mask to optax.masked raises ambiguous-bool errors, because each mask leaf is evaluated in a Python if statement and an array with more than one element has no single truth value. Use a pytree of Python booleans instead.
  • Weight decay adds wd * params to the gradient before the optimizer step: it modifies the gradient, not the parameter directly.

Inputs

  • params: 1-D JAX array.
  • grads: 1-D JAX array of gradients.
  • lr: scalar, the SGD learning rate.
  • wd: scalar, the weight decay coefficient.
  • mask: 1-D JAX array of 0.0/1.0, where 1.0 means apply WD to that element.

Output

1-D array: params after one optimizer step (WD applied only where mask = 1).

Hints

optax masked weight-decay
