
Flax

Modules, layers from scratch, attention, transformer architectures, training loops, lifted transforms. Production model code in JAX (Linen API).

  1. Module with @nn.compact
  2. Module with setup() (alternative to compact)
  3. Custom Parameter Initializer
  4. init() and apply() Round-Trip
  5. nn.Sequential Composition
  6. Module with Multiple Named Sub-Modules
  7. Module That Branches on a Config Flag
  8. Three Levels of Module Nesting
  9. Multiple PRNG Streams (params, dropout)
  10. Train vs Eval Branches via train Flag
  11. Implement Dense from Scratch
  12. Implement Conv1D from Scratch
  13. Implement Conv2D with Stride and Padding
  14. Implement Transposed Convolution
  15. Implement Depthwise-Separable Convolution
  16. Implement LayerNorm with γ/β
  17. Implement BatchNorm with Mutable batch_stats
  18. Implement GroupNorm
  19. Implement RMSNorm (Modern LLM Norm)
  20. Implement Dropout with RNG Threading
  21. Scaled Dot-Product Attention
  22. Multi-Head Self-Attention with Flax
  23. Causal Multi-Head Self-Attention
  24. Cross-Attention with Flax MHA
  25. Multi-Head Attention with KV Cache
  26. Grouped-Query Attention (GQA)
  27. Multi-Query Attention (MQA)
  28. Sliding-Window Attention (Mistral-style)
  29. ALiBi: Attention with Linear Biases
  30. Block-Diagonal Attention Mask
  31. Token Embedding with Flax
  32. Sinusoidal Position Encoding
  33. Learned Position Embedding
  34. Rotary Position Embedding (RoPE)
  35. ALiBi Bias Matrix
  36. T5 Relative Position Bucketing
  37. Tied Input/Output Embedding
  38. ViT Patch Embedding
  39. Transformer Encoder Block (Pre-LN)
  40. Transformer Decoder Block (Pre-LN)
  41. Pre-LN vs Post-LN Residual Pattern
  42. Mini GPT: Decoder-Only Language Model
  43. Mini BERT: Encoder-Only Hidden States
  44. Mini T5: Encoder-Decoder with RMSNorm and Tied Embeddings
  45. Vision Transformer (Mean-Pool Variant)
  46. Vision Transformer with [CLS] Token
  47. DeiT: Data-Efficient Image Transformer
  48. SwiGLU Feed-Forward Network
  49. ResNet Basic Block
  50. ResNet Bottleneck Block
  51. Tiny ResNet Classifier
  52. Tiny U-Net
  53. GRU Cell Step
  54. LSTM Cell Step
  55. Bidirectional RNN
  56. Mixture-of-Experts FFN
  57. Squeeze-and-Excitation Block
  58. Vision-Language Fusion
  59. TrainState: One Step
  60. train_step with value_and_grad
  61. eval_step: Forward + Metrics
  62. Label-Smoothed Cross-Entropy
  63. Mixed-Precision Training Step
  64. Train with Mutable batch_stats
  65. Multi-Task Two-Head Loss
  66. Sharded Eval Loss
  67. Warmup-Cosine LR at Step
  68. Gradient Accumulation Step
  69. EMA of Parameters
  70. Orbax Save (Tree-Leaf Count)
  71. Orbax Load (Restore via Template)
  72. HF Weight Load (Kernel Transpose)
  73. Pre-train then Fine-tune (Frozen Trunk)
  74. Per-Param Weight Decay Mask
  75. Per-Param Learning Rate Multipliers
  76. Param Freezing via Grad Zeroing
  77. Test-Time Augmentation Aggregation
  78. Distributed Checkpoint (Sharding Math)
  79. nn.scan over an RNN cell
  80. nn.scan over layers
  81. nn.vmap with shared params
  82. nn.checkpoint (gradient checkpointing)
  83. nn.jit (Flax-aware JIT lift)
  84. nn.remat with checkpoint policies
  85. Composed lifts: nn.scan + nn.vmap
  86. Batched init via jax.vmap
  87. jax.lax.scan inside a Flax Module
  88. Custom lift: roll your own ensemble
  89. PartitionSpec Layout
  90. with_sharding_constraint Annotation
  91. nn.with_partitioning Annotation
  92. flax.struct.dataclass: Pytree-Friendly State
  93. Param Surgery: Kernel Replace
  94. Param Surgery: Zero Last Layer
  95. Param Surgery: Freeze First Dense
  96. Partial Init: Warm-Start From Smaller Checkpoint
  97. Multiple Mutable Collections
  98. Param Sharing: One Module, Two Call Sites
  99. shard_map Simulation: Manual SPMD
  100. Mini LM Capstone: Putting It All Together