masked: Apply WD Only To Certain Params
Why this matters
Weight decay should typically be applied to weight matrices but NOT to
bias parameters, layer-norm scales, or embedding tables; applying WD to
those can hurt convergence. optax.masked is the standard primitive for
per-parameter-group transforms; this problem demonstrates the underlying
logic using element-wise masking on a 1-D array.
The recipe
optax.masked(transform, mask) applies transform only to the pytree
leaves whose mask entry is True. For a 1-D array, the equivalent
element-wise computation is:
# Add WD only to masked positions, then apply SGD
bool_mask = mask > 0.5
wd_grads = jnp.where(bool_mask, wd * params, 0.0)
combined_grads = grads + wd_grads
optimizer = optax.sgd(lr)
...
Note: optax.masked operates on pytree nodes, not individual
elements of a 1-D array. For element-level selection, jnp.where
replicates the semantics cleanly.
Common pitfalls
- Passing a JAX array directly as the mask to optax.masked raises ambiguous-bool errors because JAX arrays can't be used in Python if statements.
- Weight decay adds wd * params to the gradient before the optimizer step: it modifies the gradient, not the parameter directly.
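The second pitfall can be checked numerically. The sketch below, using made-up values, computes the step by adding wd * params to the gradient and compares it with the algebraically equivalent form for plain SGD, where masked elements are scaled by (1 - lr * wd) before the gradient step. It uses only jax.numpy and assumes plain SGD with no momentum.

```python
import jax.numpy as jnp

# Illustrative values only; elements 0 and 2 receive weight decay.
params = jnp.array([1.0, 2.0, 3.0])
grads = jnp.array([0.1, 0.1, 0.1])
mask = jnp.array([1.0, 0.0, 1.0])
lr, wd = 0.1, 0.5

# Gradient form: the decay term is added to the gradient, then SGD is applied.
decayed_grads = grads + jnp.where(mask > 0.5, wd * params, 0.0)
step_a = params - lr * decayed_grads

# Expanded form of the same expression:
# params - lr * (grads + wd * params) == (1 - lr * wd) * params - lr * grads
shrink = jnp.where(mask > 0.5, 1.0 - lr * wd, 1.0)
step_b = shrink * params - lr * grads
```

Both arrays agree up to floating-point rounding, which is the expected result of expanding the update by hand.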
Inputs
- params: 1-D JAX array.
- grads: 1-D JAX array of gradients.
- lr: scalar, the SGD learning rate.
- wd: scalar, the weight decay coefficient.
- mask: 1-D JAX array of 0.0/1.0; 1.0 means apply WD to that element.
Output
1-D array: params after one optimizer step (WD applied only where mask = 1).