Problem of the day: NNX Mutability Demo

Problems

Title Difficulty
Implement KL Divergence medium
MLM Eval: Masked Accuracy medium
Matrix Multiplication easy
Hessian of a Quadratic easy
Implement Sigmoid easy
Implement Mean Squared Error easy
SwiGLU Activation medium
Implement Tanh easy
Implement Cosine Similarity easy
Vectorize with vmap medium
MLM Forward Pass medium
Flash Attention Score Computation hard
Adapter Layer medium
Transpose a Matrix easy
Greedy Decoding easy
Implement Group Norm medium
Bidirectional RNN medium
Causal Attention Mask easy
Implement Embedding Lookup easy
Training Loop easy
GLU Activation easy
Implement Conv2d medium
Concatenate Tensors easy
Ridge Regression easy
2-D FFT DC Component hard
Implement Positional Encoding hard
Logistic Regression From Scratch easy
Transposed Conv2d hard
Data Augmentation medium
Implement GroupNorm medium
Broadcasting Addition easy
Transformer Encoder Block hard
Multi-Class Classifier easy
Efficient Attention with Masking hard
Top-k Sampling medium
Element-wise Operations easy
Create a Tensor from a List easy
Scatter Add medium
Gradient Clipping easy
Triplet Loss medium
fori vs scan vs while medium
Accuracy: Multiclass easy
Binary Classification via lax.cond easy
Transformer Decoder Block hard
Grouped-Query Attention hard
Multi-Query Attention medium
AUC-ROC medium
Custom Activation with Gradient hard
Masked Fill medium
vmap with out_axes medium
Top-K Values medium
BLEU-1gram medium
Top-K Gating medium
ALiBi Position Bias medium
Relative Position Encoding medium
BPE Encode Text hard
BPE Merge Step medium
Tensor Indexing and Slicing easy
Parallel Map with vmap hard
int8 Dequantization easy
Nucleus (Top-P) Sampling medium
Classifier-Free Guidance medium
Feature Normalization Pipeline easy
Focal Loss medium
Confusion Matrix easy
Speculative Decoding Accept/Reject hard
Implement Softmax easy
Implement Leaky ReLU easy
Implement Binary Cross-Entropy Loss medium
Data Collator with Padding medium
DDPM Denoising Step hard
Mixture of Experts Routing hard
Implement Cross-Entropy Loss medium
DDPM Forward Noising medium
Prefix Tuning hard
DDPM Noise Schedule easy
Cumulative Sum easy
SVD Singular Values medium
Implement Momentum Update medium
Eval Loop with Metrics medium
Expected Calibration Error medium
Exponential Moving Average medium
Gradient Accumulation medium
Mini-Batch Training medium
Learned Absolute Position Embedding easy
Beam Search hard
Model Checkpointing easy
Stack Tensors easy
Perplexity from Logits easy
Implement Linear Layer medium
Tied Input/Output Embeddings medium
MLM Masking Strategy medium
Batched Matrix Multiply medium
Word Embedding Model medium
Implement Dropout medium
Einstein Summation medium
Implement One-Hot Encoding easy
Precision, Recall & F1 medium
Temperature Scaling easy
Polynomial Regression medium
Autoencoder medium
Ring All-Reduce hard
Implement Gradient Descent Step easy
GELU Activation easy
Subword Tokenizer: Greedy Longest-Prefix-Match medium
Linear Regression easy
Implement ReLU easy
Tokenize and Pad Batch medium
Weight Initialization easy
Reshape a Tensor easy
JIT Compile a Function medium
ViT with Mixup Augmentation medium
Contrastive Loss (InfoNCE) medium
VAE ELBO Loss medium
Depthwise Separable Convolution hard
Simple RNN Cell medium
VAE Reparameterization Trick easy
Rotary Position Embeddings hard
Binary Classifier easy
LoRA Update medium
Causal Self-Attention Block medium
Cross Attention medium
Distributed Training Step End-to-End hard
Gather Elements medium
Compute Jacobian hard
DDPM Sampler Loop hard
DDPM Training Step End-to-End medium
Simple GAN Generator hard
Learning Rate Scheduler medium
Mixed Precision Forward Pass hard
Multi-Head Attention Block hard
RMSNorm easy
Train Binary Classifier End-to-End medium
Nucleus Sampling Generation Loop hard
Label Smoothing easy
Two-Layer MLP easy
GRU Cell medium
Implement Scaled Dot-Product Attention hard
Implement Batch Normalization medium
Train Image Classifier End-to-End medium
Implement Max Pooling 1D medium
Train LoRA Adapter End-to-End medium
LSTM Cell medium
Train Multiclass Classifier with Early Stopping medium
Train RNN Language Model hard
Train VAE End-to-End medium
Sliding Window Attention medium
Train with Adam End-to-End medium
Implement Conv2D with Stride and Padding medium
Skip Connection Block medium
Transformer Block with RoPE hard
Implement 1D Convolution hard
Implement Adam Optimizer Step hard
Implement Average Pooling 1D medium
Causal LM Forward Pass medium
Sequence Classifier medium
Implement L2 Regularization easy
Compute Gradient medium
Functional Update with .at[].set() easy
Self-Attention Layer hard
LLaMA-Style Transformer Block hard
KV Cache for Autoregressive Decoding hard
Apply Along an Axis medium
Causal LM with KV Cache Generation hard
Implement Layer Normalization medium
Cross-Attention Block medium
Stochastic Depth medium
Encoder-Decoder Transformer Forward Pass hard
Squeeze-and-Excitation Block medium
Simple CNN medium
Encoder-Decoder Greedy Decode medium
Encoder-Decoder Beam Search hard
MLM Forward with Tied Output Head medium
Train ViT Classifier End-to-End hard
ViT Classification Forward hard
ViT Encoder Block medium
Patch Embedding with CLS Token medium
Train Causal LM Pretraining Step hard
Train Encoder-Decoder Seq2Seq Step hard
Train MLM Pretraining Step hard
Gradient over Pytree Params medium
Train Tiny GPT End-to-End hard
Deterministic Batch via vmap+split medium
Dtype Promotion Rules hard
Functional Update with .at[].add() easy
x^n via lax.fori_loop easy
Gradient with jax.grad easy
jit + grad Composition medium
JAX numpy vs PyTorch ops easy
PRNGKey and Split easy
Implement Depthwise-Separable Convolution hard
PRNGKey fold_in medium
Pure Function vs Impure easy
Pytree Leaves easy
Pytree Map medium
2-D Scatter via .at[].add() medium
Jit with static_argnames medium
Tracer Leak Detection medium
Tracing: Shape vs Value medium
grad vs vjp medium
jax.value_and_grad easy
Gradient Checkpointing: Basics medium
Checkpointed Deep Stack via scan medium
Checkpoint with Save Policy hard
Custom JVP: Clip with Pass-through Gradient medium
Custom VJP: Implicit Function Theorem hard
Custom VJP: Stable log1pexp medium
vmap(grad) vs grad(sum(vmap)) medium
HVP via grad-of-grad medium
HVP via jvp-of-grad medium
jacfwd vs jacrev medium
Per-Example Gradients via vmap(grad(...)) medium
vjp Basics easy
jvp Basics easy
jvp for Sensitivity Analysis medium
Microbatched Gradient Accumulation via scan hard
stop_gradient: Target Network easy
Straight-Through Estimator medium
Jacobian via Batched vjp medium
jit of scan medium
Nested vmap (vmap-of-vmap) hard
vmap with None broadcasting easy
vmap-cond vs where: cost tradeoff hard
vmap with in_axes medium
Multi-Branch Dispatch via lax.switch medium
Cumulative Sum via lax.scan easy
Scan over Layer Stack medium
Running Mean via lax.scan medium
vmap over lax.cond hard
Training Loop via lax.scan hard
Scan with Per-Step Outputs medium
vmap over lax.scan hard
Bounded Search via lax.while_loop hard
Newton's Method via lax.while_loop medium
lax.while_loop vs Python while medium
Module That Branches on a Config Flag medium
REINFORCE with Baseline hard
Dirichlet Sampling medium
Gumbel Argmax (Categorical via Trick) medium
Gumbel-Softmax hard
HMC Leapfrog Step hard
Importance Sampling hard
Bernoulli Mask Sampling easy
Categorical Sampling medium
Normal Sampling with mean and std easy
Uniform Sampling easy
Log Acceptance Ratio easy
Metropolis-Hastings Step medium
Nucleus (Top-p) Masking hard
Random Walk via lax.scan hard
REINFORCE Gradient Estimator medium
Reparameterization Trick: Gaussian medium
Reparameterization Gradient medium
Batched Sampling via vmap medium
Temperature Scaling easy
Top-k Logit Masking medium
bfloat16 Mixed Precision medium
block_until_ready for Sync easy
Host Roundtrip via device_get medium
Explicit device_put easy
disable_jit Context Manager medium
Explicit Dtype Promotion easy
Fused Elementwise Ops medium
jit with donate_argnums medium
Three Levels of Module Nesting medium
jit with static_argnames easy
jit with static_argnums medium
make_jaxpr Inspection medium
Mesh Creation medium
NamedSharding (Replicated) medium
jit with Pytree Input easy
PartitionSpec Basics medium
shard_map Basics hard
shard_map vs vmap hard
jit Retrace on Shape Change medium
shard_map with psum Collective hard
with_sharding_constraint hard
Explicit Broadcasting via broadcast_to easy
Clip and Extrema easy
NaN-Safe Mean (with mask) easy
take_along_axis (Gather) medium
Multi-Condition where medium
Grad through stop_gradient medium
Higher-Order custom_vjp hard
Jacobian via Mixed-Mode (jvp+vjp) hard
jax.linearize Primitive medium
Saved Residuals in custom_vjp medium
Associative Scan (Parallel Cumsum) hard
Dynamic Slice medium
Dynamic Update Slice medium
lax.map vs vmap medium
lax.scan with reverse=True medium
ELBO for Gaussian VI hard
Beta Distribution Sampling easy
Multivariate Normal Sampling medium
Random Permutation easy
Poisson Sampling medium
Compilation Cache Key medium
compilation_cache.set_cache_dir medium
Multi-Host JAX (Conceptual) hard
int8 Affine Quantization medium
pjit Basics hard
Shape Polymorphism (Concept) medium
jax.debug.callback medium
jax.debug.print easy
float8 Round-Trip medium
io_callback for Side Effects hard
Counting Jaxpr Equations medium
jaxpr with jit medium
jvp with Structured Tangents hard
JAX <-> NumPy Bridge easy
pure_callback Basics medium
tree_flatten / tree_unflatten Round-Trip medium
Pytree Aux Data hard
register_pytree_node medium
Selective stop_gradient (Frozen Layers) hard
Conv Padding & Strides medium
1-D FFT Magnitude medium
1-D Conv via lax.conv_general_dilated medium
2-D Conv via lax.conv_general_dilated hard
Cumulative Trapezoidal Integration hard
Cholesky Decomposition medium
slogdet for Stable log|det| medium
eigh: Symmetric Eigendecomp hard
linalg.inv easy
linalg.matrix_power medium
linalg.norm (Frobenius) easy
QR Decomposition medium
linalg.solve medium
Linear Interpolation medium
log1p / expm1 for Small Values easy
Numerically Stable logsumexp medium
Polynomial Fitting via polyfit medium
Polynomial Evaluation via polyval easy
Real FFT (rfft) medium
BCOO from Dense medium
BCOO Round-Trip easy
Sparse Matrix-Vector Product hard
Stable log_softmax medium
jnp.allclose for Tolerance Equality easy
jax.debug.callback for Side Effects medium
jax.debug.print with Multiple kwargs medium
Pretty-Printed jaxpr Length medium
AOT via jit.lower().compile() hard
@profiler.annotate_function easy
jax.named_scope for Profiler Grouping medium
Fix: Python for → vmap medium
Fix: List Append → jnp.concatenate medium
Fix: Python if → jnp.where medium
jaxpr with Multiple Args medium
profiler.StepTraceAnnotation medium
Optax Adam Step medium
AdamW with Decoupled Weight Decay medium
Adam + Weight Decay (Chain) medium
Chain: Clip + SGD medium
Constant Schedule easy
Optax EMA on Params medium
Exponential Decay Schedule easy
Full Training Step (Loss + Grad + Update) hard
Gradient Accumulation via MultiSteps hard
Global-Norm Gradient Clipping medium
inject_hyperparams for Runtime LR medium
Linear Schedule easy
Lookahead Optimizer Wrapper hard
Implement Conv1D from Scratch medium
masked: Apply WD Only To Certain Params hard
Optax SGD with Momentum easy
4-Step Training Loop with Scan + Loss Curve hard
multi_transform per-Param Group hard
Piecewise Constant Schedule medium
Optax RMSprop Step medium
Optax SGD Step easy
Train Step with Warmup Schedule hard
Train Step with Global-Norm Clipping hard
Train Step with Frozen Params (Mask) hard
Warmup + Cosine Decay medium
zero_nans for NaN-Safe Training medium
Implement BatchNorm with Mutable batch_stats hard
Implement Dropout with RNG Threading medium
init() and apply() Round-Trip medium
Implement Dense from Scratch medium
Implement RMSNorm (Modern LLM Norm) medium
Multiple PRNG Streams (params, dropout) medium
Implement LayerNorm with γ/β medium
Module with @nn.compact medium
nn.Sequential Composition medium
Implement Transposed Convolution medium
Module with setup() (alternative to compact) medium
Train vs Eval Branches via train Flag medium
Custom Parameter Initializer medium
Module with Multiple Named Sub-Modules medium
ALiBi Bias Matrix hard
DeiT: Data-Efficient Image Transformer hard
Learned Position Embedding medium
ALiBi: Attention with Linear Biases hard
Multi-Head Self-Attention with Flax medium
Causal Multi-Head Self-Attention medium
Cross-Attention with Flax MHA medium
Block-Diagonal Attention Mask medium
Multi-Query Attention (MQA) hard
Grouped-Query Attention (GQA) hard
Mini BERT: Encoder-Only Hidden States hard
Multi-Head Attention with KV Cache hard
Mini GPT: Decoder-Only Language Model hard
ViT Patch Embedding medium
Sliding-Window Attention (Mistral-style) hard
Scaled Dot-Product Attention medium
Mini T5: Encoder-Decoder with RMSNorm and Tied Embeddings hard
Pre-LN vs Post-LN Residual Pattern medium
Rotary Position Embedding (RoPE) hard
Sinusoidal Position Encoding medium
Transformer Decoder Block (Pre-LN) hard
SwiGLU Feed-Forward Network medium
T5 Relative Position Bucketing hard
Transformer Encoder Block (Pre-LN) hard
Vision Transformer with [CLS] Token hard
Tied Input/Output Embedding medium
Token Embedding with Flax medium
Vision Transformer (Mean-Pool Variant) hard
Bidirectional RNN medium
eval_step: Forward + Metrics medium
Gradient Accumulation Step hard
GRU Cell Step medium
Label-Smoothed Cross-Entropy medium
Train with Mutable batch_stats hard
ResNet Bottleneck Block hard
Warmup-Cosine LR at Step medium
Squeeze-and-Excitation Block medium
TrainState: One Step medium
LSTM Cell Step medium
Mixed-Precision Training Step hard
Mixture-of-Experts FFN hard
ResNet Basic Block hard
Tiny ResNet Classifier hard
Multi-Task Two-Head Loss medium
Sharded Eval Loss hard
train_step with value_and_grad medium
Tiny U-Net hard
Vision-Language Fusion medium
Batched init via jax.vmap medium
Composed lifts: nn.scan + nn.vmap hard
Custom lift: roll your own ensemble hard
Distributed Checkpoint (Sharding Math) medium
EMA of Parameters medium
HF Weight Load (Kernel Transpose) hard
jax.lax.scan inside a Flax Module hard
Mini LM Capstone: Putting It All Together hard
Multiple Mutable Collections medium
PartitionSpec Layout medium
nn.checkpoint (gradient checkpointing) medium
nn.vmap with shared params medium
Param Sharing: One Module, Two Call Sites medium
Per-Param Weight Decay Mask hard
Param Surgery: Zero Last Layer medium
nn.jit (Flax-aware JIT lift) medium
nn.remat with checkpoint policies hard
nn.with_partitioning Annotation medium
Orbax Load (Restore via Template) medium
shard_map Simulation: Manual SPMD hard
nn.scan over layers hard
nn.scan over an RNN cell hard
Orbax Save (Tree-Leaf Count) medium
Param Freezing via Grad Zeroing hard
Partial Init: Warm-Start From Smaller Checkpoint medium
Per-Param Learning Rate Multipliers hard
Param Surgery: Freeze First Dense medium
with_sharding_constraint Annotation medium
Pre-train then Fine-tune (Frozen Trunk) hard
flax.struct.dataclass: Pytree-Friendly State medium
Param Surgery: Kernel Replace medium
Test-Time Augmentation Aggregation medium
NNX Call vs Init medium
NNX Conditional Submodule medium
NNX Counters and Buffers medium
NNX Deep Nesting medium
NNX Frozen Attribute medium
NNX Multi-Param Module medium
NNX State Split/Merge medium
NNX Variable vs Param medium
NNX GraphDef Introspection medium
NNX Mutability Demo medium
NNX State Transform medium
NNX Module Basics medium
nnx.Rngs Container medium
NNX Multi State Types medium
NNX Module Print State medium
NNX Running Average / EMA Step medium
NNX Train/Eval Flag medium
NNX Module With Submodule medium
NNX State Filter medium
NNX Pure State Update medium
NNX Implement BatchNorm hard
NNX Implement Conv1D medium
NNX Implement Conv2D medium
NNX Implement Dense medium
NNX Implement Dropout medium
NNX Implement LayerNorm medium
NNX Implement Positional Embed medium
NNX Implement Embed medium
NNX ALiBi Attention hard
NNX Multi-Query Attention (MQA) hard
NNX Mini-GPT hard
NNX ViT with CLS Token hard
NNX Implement GroupNorm medium
NNX Causal MHA medium
NNX RoPE Attention hard
NNX Transformer Decoder Block hard
NNX Vision Transformer hard
NNX Implement RMSNorm medium
NNX Grouped-Query Attention (GQA) hard
NNX Mini-BERT hard
NNX Cross-Attention medium
NNX Sliding-Window Attention hard
NNX Transformer Encoder Block hard
NNX MHA From Scratch hard
NNX MHA With KV Cache hard
NNX ResNet Basic Block hard
NNX Tiny ResNet Classifier hard
NNX Scaled Dot-Product Attention medium
NNX SwiGLU FFN medium
NNX Tiny U-Net hard
NNX Batched Init medium
NNX Checkpoint / Remat medium
NNX Composed Transforms hard
NNX Eval Step medium
NNX Fori Loop in Module hard
NNX LR Schedule medium
NNX Gradient Accumulation hard
NNX nnx.jit medium
NNX Mixed-Precision Step hard
NNX Remat Policies hard
NNX State Sharding medium
NNX Label Smoothing medium
NNX Multi-Task Loss medium
NNX MultiMetric medium
NNX Optimizer Basics medium
NNX vmap over Batch medium
NNX Scan Layers hard
NNX Scan RNN hard
NNX Train with BatchNorm hard
NNX Train Step medium
NNX vmap over Ensemble hard
NNX Bridge: Call Linen From NNX hard
NNX Bridge: Call NNX From Linen hard
NNX Bridge: Load HuggingFace Weights into NNX hard
NNX Bridge: Shared Params Across nnx and Linen hard
NNX Bridge: State Translation NNX → Linen medium
NNX Eager Debug medium
NNX Orbax Sharded Save medium
NNX Surgery Replace medium
NNX Bridge: Train Mixed Linen+NNX Model hard
NNX PartitionSpec Layout medium
NNX Surgery Zero Layer medium
NNX Coexistence: Parity With Linen medium
NNX Data-Parallel Step hard
NNX Debug Callback medium
NNX Debug Print medium
NNX FSDP-Style Step hard
NNX Distributed Train Step hard
NNX Graphdef vs State Debug medium
NNX Orbax Sharded Load medium
NNX Port: Linen Transformer Block → NNX hard
NNX Mesh Init Simulation medium
NNX Mini-LM Capstone: Putting It All Together hard
NNX Port: Linen Dense → NNX Linear medium
NNX Tied I/O Embed medium
NNX Port: Linen MLP → NNX MLP medium
NNX Surgery Add Layer medium
NNX shard_map Simulation hard
NNX Surgery Freeze hard
NNX Tensor-Parallel Linear hard