|
|
Implement KL Divergence
|
medium
|
primitives |
|
|
|
MLM Eval - Masked Accuracy
|
medium
|
end_to_end |
|
|
|
Matrix Multiplication
|
easy
|
framework |
|
|
|
Hessian of a Quadratic
|
easy
|
primitives |
|
|
|
Implement Sigmoid
|
easy
|
primitives |
|
|
|
Implement Mean Squared Error
|
easy
|
primitives |
|
|
|
SwiGLU Activation
|
medium
|
research |
|
|
|
Implement Tanh
|
easy
|
primitives |
|
|
|
Implement Cosine Similarity
|
easy
|
primitives |
|
|
|
Vectorize with vmap
|
medium
|
framework |
|
|
|
MLM Forward Pass
|
medium
|
end_to_end |
|
|
|
Flash Attention Score Computation
|
hard
|
research |
|
|
|
Adapter Layer
|
medium
|
end_to_end |
|
|
|
Transpose a Matrix
|
easy
|
framework |
|
|
|
Greedy Decoding
|
easy
|
primitives |
|
|
|
Implement Group Norm
|
medium
|
primitives |
|
|
|
Bidirectional RNN
|
medium
|
framework |
|
|
|
Causal Attention Mask
|
easy
|
primitives |
|
|
|
Implement Embedding Lookup
|
easy
|
primitives |
|
|
|
Training Loop
|
easy
|
end_to_end |
|
|
|
GLU Activation
|
easy
|
primitives |
|
|
|
Implement Conv2d
|
medium
|
primitives |
|
|
|
Concatenate Tensors
|
easy
|
framework |
|
|
|
Ridge Regression
|
easy
|
primitives |
|
|
|
2-D FFT DC Component
|
hard
|
primitives |
|
|
|
Implement Positional Encoding
|
hard
|
primitives |
|
|
|
Logistic Regression From Scratch
|
easy
|
end_to_end |
|
|
|
Transposed Conv2d
|
hard
|
primitives |
|
|
|
Data Augmentation
|
medium
|
end_to_end |
|
|
|
Implement GroupNorm
|
medium
|
primitives |
|
|
|
Broadcasting Addition
|
easy
|
framework |
|
|
|
Transformer Encoder Block
|
hard
|
end_to_end |
|
|
|
Multi-Class Classifier
|
easy
|
end_to_end |
|
|
|
Efficient Attention with Masking
|
hard
|
framework |
|
|
|
Top-k Sampling
|
medium
|
primitives |
|
|
|
Element-wise Operations
|
easy
|
framework |
|
|
|
Create a Tensor from a List
|
easy
|
framework |
|
|
|
Scatter Add
|
medium
|
framework |
|
|
|
Gradient Clipping
|
easy
|
end_to_end |
|
|
|
Triplet Loss
|
medium
|
primitives |
|
|
|
fori vs scan vs while
|
medium
|
primitives |
|
|
|
Accuracy - Multiclass
|
easy
|
primitives |
|
|
|
Binary Classification via lax.cond
|
easy
|
primitives |
|
|
|
Transformer Decoder Block
|
hard
|
end_to_end |
|
|
|
Grouped-Query Attention
|
hard
|
research |
|
|
|
Multi-Query Attention
|
medium
|
research |
|
|
|
AUC-ROC
|
medium
|
primitives |
|
|
|
Custom Activation with Gradient
|
hard
|
framework |
|
|
|
Masked Fill
|
medium
|
framework |
|
|
|
vmap with out_axes
|
medium
|
primitives |
|
|
|
Top-K Values
|
medium
|
framework |
|
|
|
BLEU-1gram
|
medium
|
primitives |
|
|
|
Top-K Gating
|
medium
|
research |
|
|
|
ALiBi Position Bias
|
medium
|
research |
|
|
|
Relative Position Encoding
|
medium
|
research |
|
|
|
BPE Encode Text
|
hard
|
end_to_end |
|
|
|
BPE Merge Step
|
medium
|
primitives |
|
|
|
Tensor Indexing and Slicing
|
easy
|
framework |
|
|
|
Parallel Map with vmap
|
hard
|
framework |
|
|
|
int8 Dequantization
|
easy
|
primitives |
|
|
|
Nucleus (Top-P) Sampling
|
medium
|
research |
|
|
|
Classifier-Free Guidance
|
medium
|
primitives |
|
|
|
Feature Normalization Pipeline
|
easy
|
end_to_end |
|
|
|
Focal Loss
|
medium
|
research |
|
|
|
Confusion Matrix
|
easy
|
primitives |
|
|
|
Speculative Decoding Accept/Reject
|
hard
|
research |
|
|
|
Implement Softmax
|
easy
|
primitives |
|
|
|
Implement Leaky ReLU
|
easy
|
primitives |
|
|
|
Implement Binary Cross-Entropy Loss
|
medium
|
primitives |
|
|
|
Data Collator with Padding
|
medium
|
primitives |
|
|
|
DDPM Denoising Step
|
hard
|
primitives |
|
|
|
Mixture of Experts Routing
|
hard
|
research |
|
|
|
Implement Cross-Entropy Loss
|
medium
|
primitives |
|
|
|
DDPM Forward Noising
|
medium
|
primitives |
|
|
|
Prefix Tuning
|
hard
|
research |
|
|
|
DDPM Noise Schedule
|
easy
|
primitives |
|
|
|
Cumulative Sum
|
easy
|
framework |
|
|
|
SVD Singular Values
|
medium
|
primitives |
|
|
|
Implement Momentum Update
|
medium
|
primitives |
|
|
|
Eval Loop with Metrics
|
medium
|
end_to_end |
|
|
|
Expected Calibration Error
|
medium
|
primitives |
|
|
|
Exponential Moving Average
|
medium
|
primitives |
|
|
|
Gradient Accumulation
|
medium
|
end_to_end |
|
|
|
Mini-Batch Training
|
medium
|
end_to_end |
|
|
|
Learned Absolute Position Embedding
|
easy
|
primitives |
|
|
|
Beam Search
|
hard
|
end_to_end |
|
|
|
Model Checkpointing
|
easy
|
end_to_end |
|
|
|
Stack Tensors
|
easy
|
framework |
|
|
|
Perplexity from Logits
|
easy
|
primitives |
|
|
|
Implement Linear Layer
|
medium
|
primitives |
|
|
|
Tied Input/Output Embeddings
|
medium
|
primitives |
|
|
|
MLM Masking Strategy
|
medium
|
end_to_end |
|
|
|
Batched Matrix Multiply
|
medium
|
framework |
|
|
|
Word Embedding Model
|
medium
|
end_to_end |
|
|
|
Implement Dropout
|
medium
|
primitives |
|
|
|
Einstein Summation
|
medium
|
framework |
|
|
|
Implement One-Hot Encoding
|
easy
|
primitives |
|
|
|
Precision, Recall & F1
|
medium
|
primitives |
|
|
|
Temperature Scaling
|
easy
|
research |
|
|
|
Polynomial Regression
|
medium
|
end_to_end |
|
|
|
Autoencoder
|
medium
|
end_to_end |
|
|
|
Ring All-Reduce
|
hard
|
primitives |
|
|
|
Implement Gradient Descent Step
|
easy
|
primitives |
|
|
|
GELU Activation
|
easy
|
research |
|
|
|
Subword Tokenizer: Greedy Longest-Prefix-Match
|
medium
|
primitives |
|
|
|
Linear Regression
|
easy
|
end_to_end |
|
|
|
Implement ReLU
|
easy
|
primitives |
|
|
|
Tokenize and Pad Batch
|
medium
|
end_to_end |
|
|
|
Weight Initialization
|
easy
|
end_to_end |
|
|
|
Reshape a Tensor
|
easy
|
framework |
|
|
|
JIT Compile a Function
|
medium
|
framework |
|
|
|
ViT with Mixup Augmentation
|
medium
|
end_to_end |
|
|
|
Contrastive Loss (InfoNCE)
|
medium
|
research |
|
|
|
VAE ELBO Loss
|
medium
|
primitives |
|
|
|
Depthwise Separable Convolution
|
hard
|
research |
|
|
|
Simple RNN Cell
|
medium
|
end_to_end |
|
|
|
VAE Reparameterization Trick
|
easy
|
primitives |
|
|
|
Rotary Position Embeddings
|
hard
|
research |
|
|
|
Binary Classifier
|
easy
|
end_to_end |
|
|
|
LoRA Update
|
medium
|
research |
|
|
|
Causal Self-Attention Block
|
medium
|
end_to_end |
|
|
|
Cross Attention
|
medium
|
research |
|
|
|
Distributed Training Step End-to-End
|
hard
|
end_to_end |
|
|
|
Gather Elements
|
medium
|
framework |
|
|
|
Compute Jacobian
|
hard
|
framework |
|
|
|
DDPM Sampler Loop
|
hard
|
end_to_end |
|
|
|
DDPM Training Step End-to-End
|
medium
|
end_to_end |
|
|
|
Simple GAN Generator
|
hard
|
end_to_end |
|
|
|
Learning Rate Scheduler
|
medium
|
end_to_end |
|
|
|
Mixed Precision Forward Pass
|
hard
|
framework |
|
|
|
Multi-Head Attention Block
|
hard
|
end_to_end |
|
|
|
RMSNorm
|
easy
|
research |
|
|
|
Train Binary Classifier End-to-End
|
medium
|
end_to_end |
|
|
|
Nucleus Sampling Generation Loop
|
hard
|
end_to_end |
|
|
|
Label Smoothing
|
easy
|
research |
|
|
|
Two-Layer MLP
|
easy
|
end_to_end |
|
|
|
GRU Cell
|
medium
|
end_to_end |
|
|
|
Implement Scaled Dot-Product Attention
|
hard
|
primitives |
|
|
|
Implement Batch Normalization
|
medium
|
primitives |
|
|
|
Train Image Classifier End-to-End
|
medium
|
end_to_end |
|
|
|
Implement Max Pooling 1D
|
medium
|
primitives |
|
|
|
Train LoRA Adapter End-to-End
|
medium
|
end_to_end |
|
|
|
LSTM Cell
|
medium
|
end_to_end |
|
|
|
Train Multiclass Classifier with Early Stopping
|
medium
|
end_to_end |
|
|
|
Train RNN Language Model
|
hard
|
end_to_end |
|
|
|
Train VAE End-to-End
|
medium
|
end_to_end |
|
|
|
Sliding Window Attention
|
medium
|
research |
|
|
|
Train with Adam End-to-End
|
medium
|
end_to_end |
|
|
|
Implement Conv2D with Stride and Padding
|
medium
|
primitives |
|
|
|
Skip Connection Block
|
medium
|
end_to_end |
|
|
|
Transformer Block with RoPE
|
hard
|
end_to_end |
|
|
|
Implement 1D Convolution
|
hard
|
primitives |
|
|
|
Implement Adam Optimizer Step
|
hard
|
primitives |
|
|
|
Implement Average Pooling 1D
|
medium
|
primitives |
|
|
|
Causal LM Forward Pass
|
medium
|
end_to_end |
|
|
|
Sequence Classifier
|
medium
|
end_to_end |
|
|
|
Implement L2 Regularization
|
easy
|
primitives |
|
|
|
Compute Gradient
|
medium
|
framework |
|
|
|
Functional Update with .at[].set()
|
easy
|
primitives |
|
|
|
Self-Attention Layer
|
hard
|
end_to_end |
|
|
|
LLaMA-Style Transformer Block
|
hard
|
end_to_end |
|
|
|
KV Cache for Autoregressive Decoding
|
hard
|
research |
|
|
|
Apply Along an Axis
|
medium
|
framework |
|
|
|
Causal LM with KV Cache Generation
|
hard
|
end_to_end |
|
|
|
Implement Layer Normalization
|
medium
|
primitives |
|
|
|
Cross-Attention Block
|
medium
|
end_to_end |
|
|
|
Stochastic Depth
|
medium
|
research |
|
|
|
Encoder-Decoder Transformer Forward Pass
|
hard
|
end_to_end |
|
|
|
Squeeze-and-Excitation Block
|
medium
|
research |
|
|
|
Simple CNN
|
medium
|
end_to_end |
|
|
|
Encoder-Decoder Greedy Decode
|
medium
|
end_to_end |
|
|
|
Encoder-Decoder Beam Search
|
hard
|
end_to_end |
|
|
|
MLM Forward with Tied Output Head
|
medium
|
end_to_end |
|
|
|
Train ViT Classifier End-to-End
|
hard
|
end_to_end |
|
|
|
ViT Classification Forward
|
hard
|
end_to_end |
|
|
|
ViT Encoder Block
|
medium
|
end_to_end |
|
|
|
Patch Embedding with CLS Token
|
medium
|
end_to_end |
|
|
|
Train Causal LM Pretraining Step
|
hard
|
end_to_end |
|
|
|
Train Encoder-Decoder Seq2Seq Step
|
hard
|
end_to_end |
|
|
|
Train MLM Pretraining Step
|
hard
|
end_to_end |
|
|
|
Gradient over Pytree Params
|
medium
|
primitives |
|
|
|
Train Tiny GPT End-to-End
|
hard
|
end_to_end |
|
|
|
Deterministic Batch via vmap+split
|
medium
|
primitives |
|
|
|
Dtype Promotion Rules
|
hard
|
primitives |
|
|
|
Functional Update with .at[].add()
|
easy
|
primitives |
|
|
|
x^n via lax.fori_loop
|
easy
|
primitives |
|
|
|
Gradient with jax.grad
|
easy
|
primitives |
|
|
|
jit + grad Composition
|
medium
|
primitives |
|
|
|
JAX numpy vs PyTorch ops
|
easy
|
primitives |
|
|
|
PRNGKey and Split
|
easy
|
primitives |
|
|
|
Implement Depthwise-Separable Convolution
|
hard
|
primitives |
|
|
|
PRNGKey fold_in
|
medium
|
primitives |
|
|
|
Pure Function vs Impure
|
easy
|
end_to_end |
|
|
|
Pytree Leaves
|
easy
|
primitives |
|
|
|
Pytree Map
|
medium
|
primitives |
|
|
|
2-D Scatter via .at[].add()
|
medium
|
primitives |
|
|
|
Jit with static_argnames
|
medium
|
end_to_end |
|
|
|
Tracer Leak Detection
|
medium
|
end_to_end |
|
|
|
Tracing: Shape vs Value
|
medium
|
end_to_end |
|
|
|
grad vs vjp
|
medium
|
primitives |
|
|
|
jax.value_and_grad
|
easy
|
primitives |
|
|
|
Gradient Checkpointing: Basics
|
medium
|
primitives |
|
|
|
Checkpointed Deep Stack via scan
|
medium
|
primitives |
|
|
|
Checkpoint with Save Policy
|
hard
|
primitives |
|
|
|
Custom JVP: Clip with Pass-through Gradient
|
medium
|
primitives |
|
|
|
Custom VJP: Implicit Function Theorem
|
hard
|
primitives |
|
|
|
Custom VJP: Stable log1pexp
|
medium
|
primitives |
|
|
|
vmap(grad) vs grad(sum(vmap))
|
medium
|
primitives |
|
|
|
HVP via grad-of-grad
|
medium
|
primitives |
|
|
|
HVP via jvp-of-grad
|
medium
|
primitives |
|
|
|
jacfwd vs jacrev
|
medium
|
primitives |
|
|
|
Per-Example Gradients via vmap(grad(...))
|
medium
|
primitives |
|
|
|
vjp Basics
|
easy
|
primitives |
|
|
|
jvp Basics
|
easy
|
primitives |
|
|
|
jvp for Sensitivity Analysis
|
medium
|
primitives |
|
|
|
Microbatched Gradient Accumulation via scan
|
hard
|
primitives |
|
|
|
stop_gradient: Target Network
|
easy
|
primitives |
|
|
|
Straight-Through Estimator
|
medium
|
primitives |
|
|
|
Jacobian via Batched vjp
|
medium
|
primitives |
|
|
|
jit of scan
|
medium
|
primitives |
|
|
|
Nested vmap (vmap-of-vmap)
|
hard
|
primitives |
|
|
|
vmap with None broadcasting
|
easy
|
primitives |
|
|
|
vmap-cond vs where: cost tradeoff
|
hard
|
primitives |
|
|
|
vmap with in_axes
|
medium
|
primitives |
|
|
|
Multi-Branch Dispatch via lax.switch
|
medium
|
primitives |
|
|
|
Cumulative Sum via lax.scan
|
easy
|
primitives |
|
|
|
Scan over Layer Stack
|
medium
|
primitives |
|
|
|
Running Mean via lax.scan
|
medium
|
primitives |
|
|
|
vmap over lax.cond
|
hard
|
primitives |
|
|
|
Training Loop via lax.scan
|
hard
|
primitives |
|
|
|
Scan with Per-Step Outputs
|
medium
|
primitives |
|
|
|
vmap over lax.scan
|
hard
|
primitives |
|
|
|
Bounded Search via lax.while_loop
|
hard
|
primitives |
|
|
|
Newton's Method via lax.while_loop
|
medium
|
primitives |
|
|
|
lax.while_loop vs Python while
|
medium
|
primitives |
|
|
|
Module That Branches on a Config Flag
|
medium
|
primitives |
|
|
|
REINFORCE with Baseline
|
hard
|
primitives |
|
|
|
Dirichlet Sampling
|
medium
|
primitives |
|
|
|
Gumbel Argmax (Categorical via Trick)
|
medium
|
primitives |
|
|
|
Gumbel-Softmax
|
hard
|
primitives |
|
|
|
HMC Leapfrog Step
|
hard
|
primitives |
|
|
|
Importance Sampling
|
hard
|
primitives |
|
|
|
Bernoulli Mask Sampling
|
easy
|
primitives |
|
|
|
Categorical Sampling
|
medium
|
primitives |
|
|
|
Normal Sampling with mean and std
|
easy
|
primitives |
|
|
|
Uniform Sampling
|
easy
|
primitives |
|
|
|
Log Acceptance Ratio
|
easy
|
primitives |
|
|
|
Metropolis-Hastings Step
|
medium
|
primitives |
|
|
|
Nucleus (Top-p) Masking
|
hard
|
primitives |
|
|
|
Random Walk via lax.scan
|
hard
|
primitives |
|
|
|
REINFORCE Gradient Estimator
|
medium
|
primitives |
|
|
|
Reparameterization Trick: Gaussian
|
medium
|
primitives |
|
|
|
Reparameterization Gradient
|
medium
|
primitives |
|
|
|
Batched Sampling via vmap
|
medium
|
primitives |
|
|
|
Temperature Scaling
|
easy
|
primitives |
|
|
|
Top-k Logit Masking
|
medium
|
primitives |
|
|
|
bfloat16 Mixed Precision
|
medium
|
primitives |
|
|
|
block_until_ready for Sync
|
easy
|
primitives |
|
|
|
Host Roundtrip via device_get
|
medium
|
primitives |
|
|
|
Explicit device_put
|
easy
|
primitives |
|
|
|
disable_jit Context Manager
|
medium
|
primitives |
|
|
|
Explicit Dtype Promotion
|
easy
|
primitives |
|
|
|
Fused Elementwise Ops
|
medium
|
primitives |
|
|
|
jit with donate_argnums
|
medium
|
primitives |
|
|
|
Three Levels of Module Nesting
|
medium
|
primitives |
|
|
|
jit with static_argnames
|
easy
|
primitives |
|
|
|
jit with static_argnums
|
medium
|
primitives |
|
|
|
make_jaxpr Inspection
|
medium
|
primitives |
|
|
|
Mesh Creation
|
medium
|
primitives |
|
|
|
NamedSharding (Replicated)
|
medium
|
primitives |
|
|
|
jit with Pytree Input
|
easy
|
primitives |
|
|
|
PartitionSpec Basics
|
medium
|
primitives |
|
|
|
shard_map Basics
|
hard
|
primitives |
|
|
|
shard_map vs vmap
|
hard
|
primitives |
|
|
|
jit Retrace on Shape Change
|
medium
|
primitives |
|
|
|
shard_map with psum Collective
|
hard
|
primitives |
|
|
|
with_sharding_constraint
|
hard
|
primitives |
|
|
|
Explicit Broadcasting via broadcast_to
|
easy
|
primitives |
|
|
|
Clip and Extrema
|
easy
|
primitives |
|
|
|
NaN-Safe Mean (with mask)
|
easy
|
primitives |
|
|
|
take_along_axis (Gather)
|
medium
|
primitives |
|
|
|
Multi-Condition where
|
medium
|
primitives |
|
|
|
Grad through stop_gradient
|
medium
|
primitives |
|
|
|
Higher-Order custom_vjp
|
hard
|
primitives |
|
|
|
Jacobian via Mixed-Mode (jvp+vjp)
|
hard
|
primitives |
|
|
|
jax.linearize Primitive
|
medium
|
primitives |
|
|
|
Saved Residuals in custom_vjp
|
medium
|
primitives |
|
|
|
Associative Scan (Parallel Cumsum)
|
hard
|
primitives |
|
|
|
Dynamic Slice
|
medium
|
primitives |
|
|
|
Dynamic Update Slice
|
medium
|
primitives |
|
|
|
lax.map vs vmap
|
medium
|
primitives |
|
|
|
lax.scan with reverse=True
|
medium
|
primitives |
|
|
|
ELBO for Gaussian VI
|
hard
|
primitives |
|
|
|
Beta Distribution Sampling
|
easy
|
primitives |
|
|
|
Multivariate Normal Sampling
|
medium
|
primitives |
|
|
|
Random Permutation
|
easy
|
primitives |
|
|
|
Poisson Sampling
|
medium
|
primitives |
|
|
|
Compilation Cache Key
|
medium
|
primitives |
|
|
|
compilation_cache.set_cache_dir
|
medium
|
primitives |
|
|
|
Multi-Host JAX (Conceptual)
|
hard
|
primitives |
|
|
|
int8 Affine Quantization
|
medium
|
primitives |
|
|
|
pjit Basics
|
hard
|
primitives |
|
|
|
Shape Polymorphism (Concept)
|
medium
|
primitives |
|
|
|
jax.debug.callback
|
medium
|
primitives |
|
|
|
jax.debug.print
|
easy
|
primitives |
|
|
|
float8 Round-Trip
|
medium
|
primitives |
|
|
|
io_callback for Side Effects
|
hard
|
primitives |
|
|
|
Counting Jaxpr Equations
|
medium
|
primitives |
|
|
|
jaxpr with jit
|
medium
|
primitives |
|
|
|
jvp with Structured Tangents
|
hard
|
primitives |
|
|
|
JAX <-> NumPy Bridge
|
easy
|
primitives |
|
|
|
pure_callback Basics
|
medium
|
primitives |
|
|
|
tree_flatten / tree_unflatten Round-Trip
|
medium
|
primitives |
|
|
|
Pytree Aux Data
|
hard
|
primitives |
|
|
|
register_pytree_node
|
medium
|
primitives |
|
|
|
Selective stop_gradient (Frozen Layers)
|
hard
|
primitives |
|
|
|
Conv Padding & Strides
|
medium
|
primitives |
|
|
|
1-D FFT Magnitude
|
medium
|
primitives |
|
|
|
1-D Conv via lax.conv_general_dilated
|
medium
|
primitives |
|
|
|
2-D Conv via lax.conv_general_dilated
|
hard
|
primitives |
|
|
|
Cumulative Trapezoidal Integration
|
hard
|
primitives |
|
|
|
Cholesky Decomposition
|
medium
|
primitives |
|
|
|
slogdet for Stable log|det|
|
medium
|
primitives |
|
|
|
eigh: Symmetric Eigendecomp
|
hard
|
primitives |
|
|
|
linalg.inv
|
easy
|
primitives |
|
|
|
linalg.matrix_power
|
medium
|
primitives |
|
|
|
linalg.norm (Frobenius)
|
easy
|
primitives |
|
|
|
QR Decomposition
|
medium
|
primitives |
|
|
|
linalg.solve
|
medium
|
primitives |
|
|
|
Linear Interpolation
|
medium
|
primitives |
|
|
|
log1p / expm1 for Small Values
|
easy
|
primitives |
|
|
|
Numerically Stable logsumexp
|
medium
|
primitives |
|
|
|
Polynomial Fitting via polyfit
|
medium
|
primitives |
|
|
|
Polynomial Evaluation via polyval
|
easy
|
primitives |
|
|
|
Real FFT (rfft)
|
medium
|
primitives |
|
|
|
BCOO from Dense
|
medium
|
primitives |
|
|
|
BCOO Round-Trip
|
easy
|
primitives |
|
|
|
Sparse Matrix-Vector Product
|
hard
|
primitives |
|
|
|
Stable log_softmax
|
medium
|
primitives |
|
|
|
jnp.allclose for Tolerance Equality
|
easy
|
primitives |
|
|
|
jax.debug.callback for Side Effects
|
medium
|
primitives |
|
|
|
jax.debug.print with Multiple kwargs
|
medium
|
primitives |
|
|
|
Pretty-Printed jaxpr Length
|
medium
|
primitives |
|
|
|
AOT via jit.lower().compile()
|
hard
|
primitives |
|
|
|
@profiler.annotate_function
|
easy
|
primitives |
|
|
|
jax.named_scope for Profiler Grouping
|
medium
|
primitives |
|
|
|
Fix: Python for → vmap
|
medium
|
primitives |
|
|
|
Fix: List Append → jnp.concatenate
|
medium
|
primitives |
|
|
|
Fix: Python if → jnp.where
|
medium
|
primitives |
|
|
|
jaxpr with Multiple Args
|
medium
|
primitives |
|
|
|
profiler.StepTraceAnnotation
|
medium
|
primitives |
|
|
|
Optax Adam Step
|
medium
|
primitives |
|
|
|
AdamW with Decoupled Weight Decay
|
medium
|
primitives |
|
|
|
Adam + Weight Decay (Chain)
|
medium
|
primitives |
|
|
|
Chain: Clip + SGD
|
medium
|
primitives |
|
|
|
Constant Schedule
|
easy
|
primitives |
|
|
|
Optax EMA on Params
|
medium
|
primitives |
|
|
|
Exponential Decay Schedule
|
easy
|
primitives |
|
|
|
Full Training Step (Loss + Grad + Update)
|
hard
|
primitives |
|
|
|
Gradient Accumulation via MultiSteps
|
hard
|
primitives |
|
|
|
Global-Norm Gradient Clipping
|
medium
|
primitives |
|
|
|
inject_hyperparams for Runtime LR
|
medium
|
primitives |
|
|
|
Linear Schedule
|
easy
|
primitives |
|
|
|
Lookahead Optimizer Wrapper
|
hard
|
primitives |
|
|
|
Implement Conv1D from Scratch
|
medium
|
primitives |
|
|
|
masked: Apply WD Only To Certain Params
|
hard
|
primitives |
|
|
|
Optax SGD with Momentum
|
easy
|
primitives |
|
|
|
4-Step Training Loop with Scan + Loss Curve
|
hard
|
primitives |
|
|
|
multi_transform per-Param Group
|
hard
|
primitives |
|
|
|
Piecewise Constant Schedule
|
medium
|
primitives |
|
|
|
Optax RMSprop Step
|
medium
|
primitives |
|
|
|
Optax SGD Step
|
easy
|
primitives |
|
|
|
Train Step with Warmup Schedule
|
hard
|
primitives |
|
|
|
Train Step with Global-Norm Clipping
|
hard
|
primitives |
|
|
|
Train Step with Frozen Params (Mask)
|
hard
|
primitives |
|
|
|
Warmup + Cosine Decay
|
medium
|
primitives |
|
|
|
zero_nans for NaN-Safe Training
|
medium
|
primitives |
|
|
|
Implement BatchNorm with Mutable batch_stats
|
hard
|
primitives |
|
|
|
Implement Dropout with RNG Threading
|
medium
|
primitives |
|
|
|
init() and apply() Round-Trip
|
medium
|
primitives |
|
|
|
Implement Dense from Scratch
|
medium
|
primitives |
|
|
|
Implement RMSNorm (Modern LLM Norm)
|
medium
|
primitives |
|
|
|
Multiple PRNG Streams (params, dropout)
|
medium
|
primitives |
|
|
|
Implement LayerNorm with γ/β
|
medium
|
primitives |
|
|
|
Module with @nn.compact
|
medium
|
primitives |
|
|
|
nn.Sequential Composition
|
medium
|
primitives |
|
|
|
Implement Transposed Convolution
|
medium
|
primitives |
|
|
|
Module with setup() (alternative to compact)
|
medium
|
primitives |
|
|
|
Train vs Eval Branches via train Flag
|
medium
|
primitives |
|
|
|
Custom Parameter Initializer
|
medium
|
primitives |
|
|
|
Module with Multiple Named Sub-Modules
|
medium
|
primitives |
|
|
|
ALiBi Bias Matrix
|
hard
|
primitives |
|
|
|
DeiT - Data-Efficient Image Transformer
|
hard
|
primitives |
|
|
|
Learned Position Embedding
|
medium
|
primitives |
|
|
|
ALiBi: Attention with Linear Biases
|
hard
|
primitives |
|
|
|
Multi-Head Self-Attention with Flax
|
medium
|
primitives |
|
|
|
Causal Multi-Head Self-Attention
|
medium
|
primitives |
|
|
|
Cross-Attention with Flax MHA
|
medium
|
primitives |
|
|
|
Block-Diagonal Attention Mask
|
medium
|
primitives |
|
|
|
Multi-Query Attention (MQA)
|
hard
|
primitives |
|
|
|
Grouped-Query Attention (GQA)
|
hard
|
primitives |
|
|
|
Mini BERT - Encoder-Only Hidden States
|
hard
|
primitives |
|
|
|
Multi-Head Attention with KV Cache
|
hard
|
primitives |
|
|
|
Mini GPT - Decoder-Only Language Model
|
hard
|
primitives |
|
|
|
ViT Patch Embedding
|
medium
|
primitives |
|
|
|
Sliding-Window Attention (Mistral-style)
|
hard
|
primitives |
|
|
|
Scaled Dot-Product Attention
|
medium
|
primitives |
|
|
|
Mini T5 - Encoder-Decoder with RMSNorm and Tied Embeddings
|
hard
|
primitives |
|
|
|
Pre-LN vs Post-LN Residual Pattern
|
medium
|
primitives |
|
|
|
Rotary Position Embedding (RoPE)
|
hard
|
primitives |
|
|
|
Sinusoidal Position Encoding
|
medium
|
primitives |
|
|
|
Transformer Decoder Block (Pre-LN)
|
hard
|
primitives |
|
|
|
SwiGLU Feed-Forward Network
|
medium
|
primitives |
|
|
|
T5 Relative Position Bucketing
|
hard
|
primitives |
|
|
|
Transformer Encoder Block (Pre-LN)
|
hard
|
primitives |
|
|
|
Vision Transformer with [CLS] Token
|
hard
|
primitives |
|
|
|
Tied Input/Output Embedding
|
medium
|
primitives |
|
|
|
Token Embedding with Flax
|
medium
|
primitives |
|
|
|
Vision Transformer (Mean-Pool Variant)
|
hard
|
primitives |
|
|
|
Bidirectional RNN
|
medium
|
primitives |
|
|
|
eval_step - Forward + Metrics
|
medium
|
primitives |
|
|
|
Gradient Accumulation Step
|
hard
|
primitives |
|
|
|
GRU Cell Step
|
medium
|
primitives |
|
|
|
Label-Smoothed Cross-Entropy
|
medium
|
primitives |
|
|
|
Train with Mutable batch_stats
|
hard
|
primitives |
|
|
|
ResNet Bottleneck Block
|
hard
|
primitives |
|
|
|
Warmup-Cosine LR at Step
|
medium
|
primitives |
|
|
|
Squeeze-and-Excitation Block
|
medium
|
primitives |
|
|
|
TrainState - One Step
|
medium
|
primitives |
|
|
|
LSTM Cell Step
|
medium
|
primitives |
|
|
|
Mixed-Precision Training Step
|
hard
|
primitives |
|
|
|
Mixture-of-Experts FFN
|
hard
|
primitives |
|
|
|
ResNet Basic Block
|
hard
|
primitives |
|
|
|
Tiny ResNet Classifier
|
hard
|
primitives |
|
|
|
Multi-Task Two-Head Loss
|
medium
|
primitives |
|
|
|
Sharded Eval Loss
|
hard
|
primitives |
|
|
|
train_step with value_and_grad
|
medium
|
primitives |
|
|
|
Tiny U-Net
|
hard
|
primitives |
|
|
|
Vision-Language Fusion
|
medium
|
primitives |
|
|
|
Batched init via jax.vmap
|
medium
|
primitives |
|
|
|
Composed lifts: nn.scan + nn.vmap
|
hard
|
primitives |
|
|
|
Custom lift: roll your own ensemble
|
hard
|
primitives |
|
|
|
Distributed Checkpoint (Sharding Math)
|
medium
|
primitives |
|
|
|
EMA of Parameters
|
medium
|
primitives |
|
|
|
HF Weight Load (Kernel Transpose)
|
hard
|
primitives |
|
|
|
jax.lax.scan inside a Flax Module
|
hard
|
primitives |
|
|
|
Mini LM Capstone - Putting It All Together
|
hard
|
primitives |
|
|
|
Multiple Mutable Collections
|
medium
|
primitives |
|
|
|
PartitionSpec Layout
|
medium
|
primitives |
|
|
|
nn.checkpoint (gradient checkpointing)
|
medium
|
primitives |
|
|
|
nn.vmap with shared params
|
medium
|
primitives |
|
|
|
Param Sharing - One Module, Two Call Sites
|
medium
|
primitives |
|
|
|
Per-Param Weight Decay Mask
|
hard
|
primitives |
|
|
|
Param Surgery - Zero Last Layer
|
medium
|
primitives |
|
|
|
nn.jit (Flax-aware JIT lift)
|
medium
|
primitives |
|
|
|
nn.remat with checkpoint policies
|
hard
|
primitives |
|
|
|
nn.with_partitioning Annotation
|
medium
|
primitives |
|
|
|
Orbax Load (Restore via Template)
|
medium
|
primitives |
|
|
|
shard_map Simulation - Manual SPMD
|
hard
|
primitives |
|
|
|
nn.scan over layers
|
hard
|
primitives |
|
|
|
nn.scan over an RNN cell
|
hard
|
primitives |
|
|
|
Orbax Save (Tree-Leaf Count)
|
medium
|
primitives |
|
|
|
Param Freezing via Grad Zeroing
|
hard
|
primitives |
|
|
|
Partial Init - Warm-Start From Smaller Checkpoint
|
medium
|
primitives |
|
|
|
Per-Param Learning Rate Multipliers
|
hard
|
primitives |
|
|
|
Param Surgery - Freeze First Dense
|
medium
|
primitives |
|
|
|
with_sharding_constraint Annotation
|
medium
|
primitives |
|
|
|
Pre-train then Fine-tune (Frozen Trunk)
|
hard
|
primitives |
|
|
|
flax.struct.dataclass - Pytree-Friendly State
|
medium
|
primitives |
|
|
|
Param Surgery - Kernel Replace
|
medium
|
primitives |
|
|
|
Test-Time Augmentation Aggregation
|
medium
|
primitives |
|
|
|
NNX Call vs Init
|
medium
|
primitives |
|
|
|
NNX Conditional Submodule
|
medium
|
primitives |
|
|
|
NNX Counters and Buffers
|
medium
|
primitives |
|
|
|
NNX Deep Nesting
|
medium
|
primitives |
|
|
|
NNX Frozen Attribute
|
medium
|
primitives |
|
|
|
NNX Multi-Param Module
|
medium
|
primitives |
|
|
|
NNX State Split/Merge
|
medium
|
primitives |
|
|
|
NNX Variable vs Param
|
medium
|
primitives |
|
|
|
NNX GraphDef Introspection
|
medium
|
primitives |
|
|
|
NNX Mutability Demo
|
medium
|
primitives |
|
|
|
NNX State Transform
|
medium
|
primitives |
|
|
|
NNX Module Basics
|
medium
|
primitives |
|
|
|
nnx.Rngs Container
|
medium
|
primitives |
|
|
|
NNX Multi State Types
|
medium
|
primitives |
|
|
|
NNX Module Print State
|
medium
|
primitives |
|
|
|
NNX Running Average / EMA Step
|
medium
|
primitives |
|
|
|
NNX Train/Eval Flag
|
medium
|
primitives |
|
|
|
NNX Module With Submodule
|
medium
|
primitives |
|
|
|
NNX State Filter
|
medium
|
primitives |
|
|
|
NNX Pure State Update
|
medium
|
primitives |
|
|
|
NNX Implement BatchNorm
|
hard
|
primitives |
|
|
|
NNX Implement Conv1D
|
medium
|
primitives |
|
|
|
NNX Implement Conv2D
|
medium
|
primitives |
|
|
|
NNX Implement Dense
|
medium
|
primitives |
|
|
|
NNX Implement Dropout
|
medium
|
primitives |
|
|
|
NNX Implement LayerNorm
|
medium
|
primitives |
|
|
|
NNX Implement Positional Embed
|
medium
|
primitives |
|
|
|
NNX Implement Embed
|
medium
|
primitives |
|
|
|
NNX ALiBi Attention
|
hard
|
primitives |
|
|
|
NNX Multi-Query Attention (MQA)
|
hard
|
primitives |
|
|
|
NNX Mini-GPT
|
hard
|
primitives |
|
|
|
NNX ViT with CLS Token
|
hard
|
primitives |
|
|
|
NNX Implement GroupNorm
|
medium
|
primitives |
|
|
|
NNX Causal MHA
|
medium
|
primitives |
|
|
|
NNX RoPE Attention
|
hard
|
primitives |
|
|
|
NNX Transformer Decoder Block
|
hard
|
primitives |
|
|
|
NNX Vision Transformer
|
hard
|
primitives |
|
|
|
NNX Implement RMSNorm
|
medium
|
primitives |
|
|
|
NNX Grouped-Query Attention (GQA)
|
hard
|
primitives |
|
|
|
NNX Mini-BERT
|
hard
|
primitives |
|
|
|
NNX Cross-Attention
|
medium
|
primitives |
|
|
|
NNX Sliding-Window Attention
|
hard
|
primitives |
|
|
|
NNX Transformer Encoder Block
|
hard
|
primitives |
|
|
|
NNX MHA From Scratch
|
hard
|
primitives |
|
|
|
NNX MHA With KV Cache
|
hard
|
primitives |
|
|
|
NNX ResNet Basic Block
|
hard
|
primitives |
|
|
|
NNX Tiny ResNet Classifier
|
hard
|
primitives |
|
|
|
NNX Scaled Dot-Product Attention
|
medium
|
primitives |
|
|
|
NNX SwiGLU FFN
|
medium
|
primitives |
|
|
|
NNX Tiny U-Net
|
hard
|
primitives |
|
|
|
NNX Batched Init
|
medium
|
primitives |
|
|
|
NNX Checkpoint / Remat
|
medium
|
primitives |
|
|
|
NNX Composed Transforms
|
hard
|
primitives |
|
|
|
NNX Eval Step
|
medium
|
primitives |
|
|
|
NNX Fori Loop in Module
|
hard
|
primitives |
|
|
|
NNX LR Schedule
|
medium
|
primitives |
|
|
|
NNX Gradient Accumulation
|
hard
|
primitives |
|
|
|
NNX nnx.jit
|
medium
|
primitives |
|
|
|
NNX Mixed-Precision Step
|
hard
|
primitives |
|
|
|
NNX Remat Policies
|
hard
|
primitives |
|
|
|
NNX State Sharding
|
medium
|
primitives |
|
|
|
NNX Label Smoothing
|
medium
|
primitives |
|
|
|
NNX Multi-Task Loss
|
medium
|
primitives |
|
|
|
NNX MultiMetric
|
medium
|
primitives |
|
|
|
NNX Optimizer Basics
|
medium
|
primitives |
|
|
|
NNX vmap over Batch
|
medium
|
primitives |
|
|
|
NNX Scan Layers
|
hard
|
primitives |
|
|
|
NNX Scan RNN
|
hard
|
primitives |
|
|
|
NNX Train with BatchNorm
|
hard
|
primitives |
|
|
|
NNX Train Step
|
medium
|
primitives |
|
|
|
NNX vmap over Ensemble
|
hard
|
primitives |
|
|
|
NNX Bridge: Call Linen From NNX
|
hard
|
primitives |
|
|
|
NNX Bridge: Call NNX From Linen
|
hard
|
primitives |
|
|
|
NNX Bridge: Load HuggingFace Weights into NNX
|
hard
|
primitives |
|
|
|
NNX Bridge: Shared Params Across nnx and Linen
|
hard
|
primitives |
|
|
|
NNX Bridge: State Translation NNX → Linen
|
medium
|
primitives |
|
|
|
NNX Eager Debug
|
medium
|
primitives |
|
|
|
NNX Orbax Sharded Save
|
medium
|
primitives |
|
|
|
NNX Surgery Replace
|
medium
|
primitives |
|
|
|
NNX Bridge: Train Mixed Linen+NNX Model
|
hard
|
primitives |
|
|
|
NNX PartitionSpec Layout
|
medium
|
primitives |
|
|
|
NNX Surgery Zero Layer
|
medium
|
primitives |
|
|
|
NNX Coexistence: Parity With Linen
|
medium
|
primitives |
|
|
|
NNX Data-Parallel Step
|
hard
|
primitives |
|
|
|
NNX Debug Callback
|
medium
|
primitives |
|
|
|
NNX Debug Print
|
medium
|
primitives |
|
|
|
NNX FSDP-Style Step
|
hard
|
primitives |
|
|
|
NNX Distributed Train Step
|
hard
|
primitives |
|
|
|
NNX Graphdef vs State Debug
|
medium
|
primitives |
|
|
|
NNX Orbax Sharded Load
|
medium
|
primitives |
|
|
|
NNX Port: Linen Transformer Block → NNX
|
hard
|
primitives |
|
|
|
NNX Mesh Init Simulation
|
medium
|
primitives |
|
|
|
NNX Mini-LM Capstone - Putting It All Together
|
hard
|
primitives |
|
|
|
NNX Port: Linen Dense → NNX Linear
|
medium
|
primitives |
|
|
|
NNX Tied I/O Embed
|
medium
|
primitives |
|
|
|
NNX Port: Linen MLP → NNX MLP
|
medium
|
primitives |
|
|
|
NNX Surgery Add Layer
|
medium
|
primitives |
|
|
|
NNX shard_map Simulation
|
hard
|
primitives |
|
|
|
NNX Surgery Freeze
|
hard
|
primitives |
|
|
|
NNX Tensor-Parallel Linear
|
hard
|
primitives |
|