Transformer Architecture Deep Dive


Build every piece of a decoder-only transformer by hand — scaled dot-product attention, multi-head attention, the full block with residuals and LayerNorm, then assemble a tiny GPT and train it. No shortcuts, no pre-built attention modules.

50 min · 4 steps · 3 domains · Advanced · NCP-GENL · NCA-GENL · NCP-ADS · NCA-GENM

What you'll learn

  1. Scaled dot-product attention
  2. Multi-Head Attention
  3. The Transformer Block
  4. Build a tiny GPT, train it, generate

Prerequisites

  • Strong PyTorch fundamentals and nn.Module
  • Linear algebra: matrix multiplication, softmax
  • Comfortable with attention intuition (Q, K, V)

Exam domains covered

LLM Integration and Development · GPU Acceleration & Distributed Training · Experimentation

Skills & technologies you'll practice

This advanced-level GPU lab gives you real-world reps across:

Transformer · Attention · Multi-Head Attention · Causal Mask · GPT · PyTorch · Residual Connections · LayerNorm

What you'll build in this transformer-from-scratch lab

Every frontier LLM in 2026 (Llama 3, Mistral, Qwen, GPT-4) is a decoder-only transformer plus a short list of named upgrades. If you can't build the vanilla version from scratch, those upgrades (RMSNorm, RoPE, SwiGLU, grouped-query attention) stay incantations. In 50 minutes you'll write every piece by hand, with no nn.MultiheadAttention or nn.TransformerEncoder shortcuts, train the result on a short repeating pattern until loss drops by 30%+ and it produces coherent output, and walk away with a concrete mental model of why each modern architectural choice exists: why pre-norm replaced post-norm starting with GPT-2, why sqrt(d_k) matters for gradient flow, why the causal mask goes before softmax rather than after, and why nn.Linear(D, 3*D) fuses the Q/K/V projections into one BLAS call.

The technical substance is the four pieces plus the geometry traps that catch everyone the first time:

  1. Scaled dot-product attention: Q @ K.transpose(-2, -1) / sqrt(d_k), then masked_fill(..., float('-inf')) above the diagonal, softmax(dim=-1), then multiply by V. The checker asserts triu(attention_weights, diagonal=1).abs().max() < 1e-6 to catch mask leakage, and row sums must equal 1.0 to catch a misplaced softmax.
  2. Multi-head attention reshapes Q/K/V to (B, n_heads, T, d_head) with d_head = D / n_heads and runs your step-1 attention per head in parallel.
  3. The GPT-style block uses pre-norm residuals (x = x + MHA(LN(x)), then x = x + MLP(LN(x))) with the MLP as Linear(D, 4D) → GELU → Linear(4D, D). The checker forces a backward pass to confirm gradient flows through the residual stream.
  4. The final tiny GPT stacks token + positional embeddings → N blocks → LayerNorm → Linear(D, vocab_size) and trains with AdamW + cross-entropy.

Once you've written it, the Llama-style swaps are surgical: RMSNorm drops LayerNorm's mean-centering to save a reduction per token, RoPE rotates Q and K in 2D pairs so position becomes relative and extrapolates past training length, SwiGLU adds a gated activation branch that learns richer features per parameter, and grouped-query attention shares K/V projections across query-head clusters to shrink the KV cache during inference.
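The step-1 recipe can be sketched end to end. This is a minimal standalone version whose shapes and tolerances mirror the checks described above; the function name `causal_attention` is illustrative, not the lab's scaffolding:

```python
import math
import torch

def causal_attention(q, k, v):
    # q, k, v: (B, T, d_k). Scale scores by sqrt(d_k), then mask above
    # the diagonal BEFORE softmax so future positions get exactly 0 weight.
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)            # (B, T, T)
    mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    weights = torch.softmax(scores, dim=-1)                      # rows sum to 1
    return weights @ v, weights

B, T, D = 2, 8, 16
q, k, v = (torch.randn(B, T, D) for _ in range(3))
out, w = causal_attention(q, k, v)

# Upper triangle is exactly zero; every row is a proper distribution.
assert torch.triu(w, diagonal=1).abs().max() < 1e-6
assert torch.allclose(w.sum(dim=-1), torch.ones(B, T))
```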

Prerequisites are strong PyTorch fundamentals (nn.Module, autograd, broadcasting), linear-algebra comfort (matmul, softmax), and an intuition for Q/K/V. You don't need prior transformer-implementation experience; the hints walk through the exact tensor shapes. The sandbox is a real NVIDIA GPU pod we provision per session, with PyTorch preinstalled. Grading is mechanical and strict throughout:

  • Step 1: output shape (B=2, T=8, D=16), per-row attention sums equal to 1.0, and upper-triangular attention weights strictly zero.
  • Step 2: a MultiHeadAttention class and an mha instance must exist with n_heads > 1 and D % n_heads == 0.
  • Step 3: the full block must pass a smoke test of shape preservation plus a successful backward pass.
  • Step 4: training must log ≥20 steps with >30% loss reduction between the first-5 and last-5 averages, plus a non-empty sampled string.

Frequently asked questions

Why divide by sqrt(d_k) in scaled dot-product attention?

Because without it, the variance of the Q @ K.T logits grows linearly with d_k: each entry is a sum of d_k products of roughly unit-variance values, so typical magnitudes grow as sqrt(d_k). Large pre-softmax scores push softmax into saturation, where gradients effectively vanish for all but the largest entry; attention collapses to "pick one token, zero out everything else" and backprop can't learn. Dividing by sqrt(d_k) keeps the variance of the pre-softmax logits stable across head sizes, which keeps the softmax in its well-behaved region and the gradient flowing.
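The variance argument is easy to verify numerically. A quick sketch, assuming unit-variance Q and K (the 10,000-sample size is arbitrary):

```python
import math
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(10_000, d_k)   # unit-variance query rows
k = torch.randn(10_000, d_k)   # unit-variance key rows
raw = (q * k).sum(-1)          # one dot-product logit per row
scaled = raw / math.sqrt(d_k)

# Raw logits have variance ~d_k; dividing by sqrt(d_k) restores variance ~1.
print(raw.var().item())        # ~512
print(scaled.var().item())     # ~1.0
```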

Why pre-norm (x + MHA(LN(x))) instead of post-norm LN(x + MHA(x))?

Because pre-norm has materially better gradient flow at depth. The original Transformer and GPT-1 used post-norm and needed careful warmup plus gradient clipping to train stably; GPT-2 moved LayerNorm to the input of each sub-block (pre-norm), and every major modern decoder-only architecture (GPT-3, Llama, Mistral, Qwen, Gemma) kept it. With pre-norm, the residual path is a clean identity stream that each block adds an update to, so gradients can flow through untouched from the final layer back to the embeddings. Post-norm interleaves normalization with the residual addition, which subtly attenuates gradient magnitudes per layer and gets worse the deeper you stack.
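A minimal sketch of the pre-norm wiring, with the attention sub-layer stubbed by a single linear so only the residual-stream topology is visible (`PreNormBlock` is an illustrative name, not the lab's scaffolding):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Pre-norm residuals: x = x + attn(LN(x)); x = x + mlp(LN(x)).
    # The residual path never passes through a norm, so it stays an
    # identity stream from embeddings to logits.
    def __init__(self, d):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = nn.Linear(d, d)  # stand-in for multi-head attention
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

x = torch.randn(2, 8, 16, requires_grad=True)
y = PreNormBlock(16)(x)
y.sum().backward()                  # gradient reaches the input untouched
assert x.grad is not None
```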

Why apply the causal mask with -inf instead of just zeroing out the upper triangle after softmax?

Because softmax is non-linear: zeroing post-softmax leaves the remaining entries summing to less than 1, so each row is no longer a proper probability distribution, and the pre-softmax logits at masked positions still receive gradient through the softmax denominator during backprop. masked_fill(mask, float('-inf')) applied BEFORE the softmax sends those logits to a value that exponentiates to exactly 0, which softmax's denominator safely ignores, and the masked positions contribute no gradient because their softmax output is identically zero. The Step 1 checker enforces both invariants, triu(attention_weights, diagonal=1).abs().max() < 1e-6 and per-row sums equal to 1.0; masking post-softmax still zeroes the upper triangle but fails the row-sum check.
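The difference between the two orderings is easy to see numerically. A small sketch:

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 4)
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)

# Correct: -inf before softmax -> exact zeros, rows still sum to 1.
pre = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)

# Wrong: zeroing after softmax -> zeros, but rows no longer sum to 1.
post = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

assert torch.allclose(pre.sum(-1), torch.ones(4))        # proper distributions
assert not torch.allclose(post.sum(-1), torch.ones(4))   # broken distributions
```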

Why is the MLP's inner dimension 4D?

Mostly empirical tradition carried from "Attention is All You Need" — 4× expansion gives the MLP enough capacity to do meaningful non-linear transformation per block without making the parameter count of the feedforward path dominate the model. Modern architectures revisit this: SwiGLU uses roughly 8D/3 with a gate, Llama 3 uses ~14336 for a 4096-dim model (also around 3.5D to offset the extra SwiGLU parameters), etc. The exact multiplier is a tuning knob, but 4× remains the default for vanilla GELU-based blocks.
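A back-of-envelope parameter count (biases ignored) shows how the 4× multiplier splits each block's budget:

```python
# Per-block parameter count at a Llama-scale width, biases ignored.
D = 4096
attn_params = 3 * D * D + D * D      # fused QKV projection + output projection
mlp_params = D * 4 * D + 4 * D * D   # up-projection and down-projection

# The 4x MLP holds twice the attention parameters, i.e. two thirds
# of each block's weights.
print(mlp_params / attn_params)      # 2.0
```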

Why use a single nn.Linear(D, 3*D) for Q/K/V instead of three separate linears?

Performance. A single linear projection fuses three matrix multiplications into one BLAS call, which is noticeably faster on GPU than three separate matmuls: the extra kernel-launch overhead disappears and the hardware gets one larger, better-utilized GEMM. The output is then sliced or chunked into Q/K/V along the last dimension. Functionally it's identical to three separate linears with the same weights; it's purely a compute-efficiency refactor, and every serious production implementation does it.
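A sketch of the equivalence, assuming no bias so the slices match three separate projections exactly:

```python
import torch
import torch.nn as nn

D = 16
fused = nn.Linear(D, 3 * D, bias=False)  # one GEMM for all three projections
x = torch.randn(2, 8, D)
q, k, v = fused(x).chunk(3, dim=-1)      # slice the fused output into Q/K/V

# Equivalent to three separate linears built from the corresponding weight rows.
wq, wk, wv = fused.weight.chunk(3, dim=0)
assert torch.allclose(q, x @ wq.T, atol=1e-5)
assert torch.allclose(k, x @ wk.T, atol=1e-5)
assert torch.allclose(v, x @ wv.T, atol=1e-5)
```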

What's the smallest architectural change that turns this from GPT-style into Llama-style?

Four surgical swaps, in order of how much they change: replace nn.LayerNorm with RMSNorm (drop mean-centering and bias, keep the scaling parameter), swap the learned positional embedding in the embedding lookup for RoPE rotation applied to Q and K inside multi-head attention, replace Linear(D, 4D) → GELU → Linear(4D, D) with SwiGLU (Linear(D, 2 * hidden) → split into gate and up, silu(gate) * up → Linear(hidden, D)), and switch vanilla MHA to grouped-query attention by projecting fewer K/V heads than Q heads and broadcasting. Each change is 10-30 lines of code against the lab's scaffolding, and each has a cleanly measurable win on the scaling-law / inference-cost axis it targets.
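The first of those swaps fits in a few lines. A minimal RMSNorm sketch (the eps value is illustrative):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # LayerNorm minus the mean-centering and bias: rescale each token
    # vector by its root-mean-square, then apply a learned gain.
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight

x = torch.randn(2, 8, 16)
y = RMSNorm(16)(x)
assert y.shape == x.shape   # drop-in replacement for nn.LayerNorm(16)
```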