The Complete Journey

Trace a single PyTorch distributed training call through every layer of the GPU computing stack — from Python to silicon

# Your PyTorch Distributed Training Code
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Initialize process group
dist.init_process_group(backend="nccl")  # ← NCCL for GPU communication

# Wrap model with DDP or FSDP
model = FSDP(model)  # ← This single line triggers the ENTIRE stack below

# Training loop
for batch in dataloader:
    outputs = model(batch)           # Forward: AllGather params
    loss = criterion(outputs, labels)
    loss.backward()                  # Backward: ReduceScatter grads
    optimizer.step()                  # Update sharded optimizer state

PyTorch Framework

Python → C++ → CUDA Kernels
"What happens when I call model.forward()?"
model(batch)
  ↓ Python dispatch
torch._C._nn.linear()
  ↓ ATen operators
at::native::linear()
  ↓ CUDA dispatch
cublas_gemm()

What Happens

  • Python tensors → C++ ATen tensors
  • Dynamic dispatch to CUDA kernels
  • Autograd graph construction
  • Memory allocation via caching allocator
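The dispatch chain above can be mimicked with a toy table in plain Python (the `register`/`dispatch` helpers are hypothetical; the real dispatcher lives in C++ and keys on device, dtype, and layout):

```python
# Toy model of PyTorch-style operator dispatch: one table maps
# (op name, device) to a kernel. Illustrative only.
from typing import Callable, Dict, Tuple

DISPATCH: Dict[Tuple[str, str], Callable] = {}

def register(op: str, device: str):
    def wrap(fn):
        DISPATCH[(op, device)] = fn
        return fn
    return wrap

@register("linear", "cpu")
def linear_cpu(x, w):
    # x @ w.T, the weight-shape convention nn.Linear uses
    return [[sum(a * b for a, b in zip(row, wrow)) for wrow in w] for row in x]

@register("linear", "cuda")
def linear_cuda(x, w):
    # Would launch a cuBLAS GEMM; here we just tag the path taken
    return ("cublas_gemm", x, w)

def dispatch(op: str, device: str, *args):
    return DISPATCH[(op, device)](*args)

x = [[1.0, 2.0]]
w = [[1.0, 0.0], [0.0, 1.0]]          # identity weight
print(dispatch("linear", "cpu", x, w))   # [[1.0, 2.0]]
```

The same call site routes to a different kernel purely by device tag, which is the essence of the Python → ATen → CUDA hand-off.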

Distributed Wrapper

DDP / FSDP / DeepSpeed
FSDP(model)
  ↓ Shards parameters across GPUs
FlatParameter  # Flattened shards
  ↓ Registers hooks
pre_forward_hook   → AllGather
post_backward_hook → ReduceScatter
Memory Reduction • 2 Collectives/Layer • ZeRO-3 Equivalent
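The flat-parameter mechanics can be sketched in a few lines of plain Python (toy shapes; real FSDP keeps one contiguous, padded tensor per FlatParameter):

```python
# Sketch of FSDP's FlatParameter sharding: flatten all layer parameters
# into one buffer, pad so it splits evenly, give each rank one shard,
# and reconstruct the full buffer with a (simulated) AllGather.
world_size = 4
params = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0, 7.0, 8.0]]  # toy "layers"

flat = [x for p in params for x in p]
pad = (-len(flat)) % world_size
flat += [0.0] * pad                      # pad to a multiple of world_size

shard_size = len(flat) // world_size
shards = [flat[i * shard_size:(i + 1) * shard_size] for i in range(world_size)]

# pre_forward_hook: AllGather reconstructs the full FlatParameter on every rank
gathered = [x for s in shards for x in s]
assert gathered == flat

# Steady-state parameter memory per rank drops by world_size
print(len(flat), shard_size)   # 8 2
```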

Parallelism Strategy

DP • TP • PP • ZeRO • Expert
"How is the model split across 16,384 GPUs?"
Llama 3.1 405B configuration:
TP = 8    # Tensor parallel within NVLink node
PP = 16   # Pipeline stages across nodes
DP = 128  # Data-parallel replicas
Total: 8 × 16 × 128 = 16,384 GPUs

Strategy Selection

  • DP: AllReduce gradients (1.75× data)
  • TP: AllGather/ReduceScatter activations
  • PP: P2P Send/Recv between stages
  • ZeRO: Shard everything, gather on demand
📚 Curriculum Files
16-distributed-training-2026 (33KB) • 16-distributed-training-visual (92KB) • 16-distributed-training-animated (135KB) • parallelism-strategies (107KB)
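One common way to lay global ranks onto a (DP, PP, TP) grid is to make TP the fastest-varying axis, so TP groups stay NVLink-local. A toy sketch with a small grid (the axis order is an assumption; frameworks differ):

```python
# Map a global rank to (dp, pp, tp) coordinates on a 3D parallelism grid.
# TP varies fastest so consecutive ranks (same node) form a TP group.
TP, PP, DP = 2, 2, 4          # toy grid: 2 * 2 * 4 = 16 ranks

def coords(rank):
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return dp, pp, tp

print(coords(0))    # (0, 0, 0)
print(coords(5))    # (1, 0, 1)
print(coords(15))   # (3, 1, 1)
```

Scaling the same indexing to TP×PP×DP = 16,384 gives each GPU its tensor-parallel peers on-node and its pipeline/data-parallel peers across the network.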

Collective Operations

What each strategy needs
FSDP forward pass:
  ↓ Before each layer
ncclAllGather(flat_param_shard)   # Reconstruct full params from N shards
FSDP backward pass:
  ↓ After each layer
ncclReduceScatter(gradients)      # Reduce + distribute grad shards
367 KB docs • 4 files • Chapter 16
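The data-movement contracts of these two collectives can be modeled functionally on Python lists (not NCCL calls, just the semantics):

```python
# Semantics of the two collectives FSDP issues per layer.
# rank r contributes the r-th inner list.

def all_gather(shards):
    # Every rank ends up with the concatenation of all shards
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]

def reduce_scatter(per_rank_grads):
    # Element-wise sum across ranks, then rank r keeps chunk r
    n = len(per_rank_grads)
    summed = [sum(col) for col in zip(*per_rank_grads)]
    chunk = len(summed) // n
    return [summed[r * chunk:(r + 1) * chunk] for r in range(n)]

shards = [[1, 2], [3, 4]]               # param shards on 2 ranks
print(all_gather(shards))                # [[1, 2, 3, 4], [1, 2, 3, 4]]
grads = [[1, 1, 1, 1], [2, 2, 2, 2]]     # full gradients on 2 ranks
print(reduce_scatter(grads))             # [[3, 3], [3, 3]]
```

AllGather before a layer rebuilds full parameters from shards; ReduceScatter after it both sums gradients and leaves each rank holding only its own shard.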

Attention Mechanism

Flash Attention • Memory Optimization
"How does self-attention scale to 128K context?"
Standard attention: O(N²) memory
128K × 128K × 2 bytes = 32GB per head!
  ↓ Flash Attention tiling
Flash Attention: O(N) memory
Block size 128 → fits in SRAM

Flash Attention Process

  • Tile Q, K, V into blocks
  • Compute attention in SRAM (19TB/s)
  • Online softmax with running max
  • Never materialize N×N matrix
📚 Curriculum Files
15-flash-attention-2026 (37KB) • 15-flash-attention-animated (77KB)
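The "online softmax with running max" bullet is the heart of the method; a minimal sketch comparing the streaming version against a reference softmax:

```python
# Online (streaming) softmax: process scores tile by tile, keeping only a
# running max and running denominator. This is what lets Flash Attention
# avoid materializing a full row of the N x N score matrix.
import math

def online_softmax(scores, block=2):
    m, d = float("-inf"), 0.0            # running max, running denominator
    for i in range(0, len(scores), block):
        tile = scores[i:i + block]
        m_new = max(m, max(tile))
        # Rescale the old denominator when the max changes
        d = d * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in tile)
        m = m_new
    return [math.exp(s - m) / d for s in scores]

def softmax(scores):
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    return [x / z for x in e]

s = [3.0, 1.0, 0.2, 2.5, -1.0, 4.0]
assert all(abs(a - b) < 1e-12 for a, b in zip(online_softmax(s), softmax(s)))
```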

Attention Kernels

Fused CUDA operations
flash_attn_func(q, k, v)
  ↓ Launches fused kernel
flash_fwd_kernel<<<grid, block>>>
  ↓ Per-block computation
Load Q_i, K_j, V_j into SRAM
S_ij = Q_i @ K_j.T / sqrt(d)
P_ij = softmax(S_ij)
O_i += P_ij @ V_j
114 KB docs • 2 files • Chapter 15
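The per-block loop can be reproduced at toy scale in plain Python, including the output rescaling as the running max changes. This is a sketch of the algorithm for a single query row, not the fused kernel:

```python
# Flash-Attention-style forward for one query row: consume K/V in blocks,
# rescale the accumulated output whenever the running max moves.
import math

def attn_row_tiled(q, K, V, block=2):
    d_k = len(q)
    m, denom = float("-inf"), 0.0
    o = [0.0] * len(V[0])
    for j in range(0, len(K), block):
        Kb, Vb = K[j:j + block], V[j:j + block]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in Kb]
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)           # rescale factor for old state
        p = [math.exp(x - m_new) for x in s]
        denom = denom * corr + sum(p)
        o = [oi * corr + sum(pj * v[c] for pj, v in zip(p, Vb))
             for c, oi in enumerate(o)]
        m = m_new
    return [oi / denom for oi in o]

def attn_row_naive(q, K, V):
    d_k = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, V)) / z for c in range(len(V[0]))]

q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
t, n = attn_row_tiled(q, K, V), attn_row_naive(q, K, V)
assert all(abs(a - b) < 1e-12 for a, b in zip(t, n))
```

The tiled and naive versions agree to rounding error, which is the point: tiling changes the memory footprint, not the result.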

Linear Algebra (BLAS)

GEMM • cuBLAS • rocBLAS
"How is matrix multiplication actually computed?"
nn.Linear(4096, 4096)
  ↓ Dispatches to
cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
            M, N, K,         # Dimensions
            &alpha, A, lda,  # Input A
            B, ldb,          # Input B
            &beta, C, ldc)   # Output C

GEMM Optimization

  • Tile for L2 cache (50MB on H100)
  • Use tensor cores for 16×16 blocks
  • Overlap compute with memory access
  • Batched GEMM for attention heads
📚 Curriculum Files
14-blas-communication-2026 (42KB) • 14-blas-communication (123KB)
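The tiling idea behind these optimizations, shrunk to list-of-lists scale (a sketch of the loop structure cuBLAS-style kernels use, not cuBLAS itself):

```python
# Blocked (tiled) GEMM: iterate over tiles of C and accumulate over tiles
# of K, so each tile of A and B is reused while it is "resident" (in a real
# kernel: in L2 / shared memory / registers).
def gemm_tiled(A, B, tile=2):
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, tile):
        for j0 in range(0, N, tile):
            for k0 in range(0, K, tile):          # accumulate over K tiles
                for i in range(i0, min(i0 + tile, M)):
                    for j in range(j0, min(j0 + tile, N)):
                        C[i][j] += sum(A[i][k] * B[k][j]
                                       for k in range(k0, min(k0 + tile, K)))
    return C

A = [[1, 2, 3], [4, 5, 6]]
B = [[7, 8], [9, 10], [11, 12]]
print(gemm_tiled(A, B))   # [[58.0, 64.0], [139.0, 154.0]]
```

The tile size is the knob: pick it so a C-tile plus the A/B strips feeding it fit in the fast memory level you are targeting.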

Tensor Core Dispatch

Hardware matrix units
cuBLAS selects kernel:
  ↓ For FP16/BF16
hmma.16816.f32   # Tensor core MMA
  ↓ PTX instruction
D = A × B + C    # 16×16×16 in one cycle
165 KB docs • 2 files • Chapter 14

GPU Primitives

Warps • Tensor Cores • Sparse
"How do 32 threads work together?"
Warp of 32 threads:
  ↓ Execute in lockstep
__shfl_xor_sync(mask, val, lane)   # Exchange data between lanes
  ↓ Warp-level reduction
sum = val + __shfl_down(val, 16)
sum += __shfl_down(sum, 8)
...                                # 5 shuffles = full reduce

Primitive Operations

  • Shuffle: Direct lane-to-lane data exchange
  • Vote: Warp-wide predicates
  • Tensor Core: 16×16 matrix multiply
  • Sparse: Structured sparsity (2:4)
📚 Curriculum Files
12-sparse-matrix-2026 (35KB) • 12-sparse-matrix-animated (83KB) • 13-warp-primitives-2026 (39KB) • warp-primitives-animated (47KB) • tensor-core-index (18KB) • tensor-core-library (folder)
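The five-shuffle reduction can be modeled on a 32-element list, with each step reading pre-step values the way a hardware shuffle does:

```python
# Model of a warp-level sum via __shfl_down_sync: the offset halves each
# step (16, 8, 4, 2, 1), so 5 shuffles leave the warp-wide sum in lane 0.
def warp_reduce_sum(lanes):
    vals = list(lanes)                 # one value per lane, 32 lanes
    offset = len(vals) // 2
    while offset >= 1:
        # __shfl_down(val, offset): lane i reads lane i + offset.
        # Build the new list from the old one, matching shuffle semantics
        # (all lanes read before any lane's value changes).
        vals = [v + (vals[i + offset] if i + offset < len(vals) else 0)
                for i, v in enumerate(vals)]
        offset //= 2
    return vals[0]                     # lane 0 holds the total

print(warp_reduce_sum(range(32)))   # 496
```

Out-of-range lanes contribute 0 here; on hardware those reads return the lane's own value, but only lane 0's result is consumed, so the sum is unaffected.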

Tensor Core Operation

4th Gen on H100
wmma::mma_sync(d, a, b, c);
  ↓ Hardware execution
// Per SM, per cycle:
4 Tensor Cores × ~1024 FP16 FLOPs = ~4096 FLOPs/SM/cycle
  ↓ Full H100
132 SMs × 4096 × ~1.83 GHz ≈ 989 TFLOPS FP16
261 KB docs • 6 items • Chapters 12–13
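The headline number can be sanity-checked by backing out per-SM throughput from the spec sheet (the ~1.83 GHz boost clock is an assumption, and one FMA counts as 2 FLOPs):

```python
# Back out per-SM tensor-core throughput from H100's dense FP16 figure.
SMS, CLOCK_GHZ, PEAK_TFLOPS = 132, 1.83, 989

flops_per_sm_per_clock = PEAK_TFLOPS * 1e12 / (SMS * CLOCK_GHZ * 1e9)
print(round(flops_per_sm_per_clock))        # 4094, i.e. ~4096 FLOPs/SM/clock
# With 4 tensor cores per SM, that is ~1024 FLOPs (512 FMAs) per
# tensor core per clock.
print(round(flops_per_sm_per_clock / 4))    # ~1024
```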

Software Stack

CUDA • HIP • NCCL • cuDNN
"How does NCCL implement AllReduce?"
ncclAllReduce(sendbuff, recvbuff, count,
              datatype, ncclSum, comm, stream)
  ↓ Algorithm selection
if (size > 256KB && nRanks <= 8)
    → Ring algorithm
else
    → Tree algorithm

NCCL Ring AllReduce

  • Split data into N chunks
  • Reduce-Scatter: N-1 steps
  • AllGather: N-1 more steps
  • Result: 2(N−1)/N × data sent per GPU — bandwidth-optimal
📚 Curriculum Files
11-compilers-2026 (43KB) • cuda-vs-hip (44KB) • cudnn-vs-miopen (84KB) • nccl-vs-rccl (167KB)
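Both phases can be simulated rank-by-rank to check the schedule and count the traffic (plain Python, standard ring chunk indexing):

```python
# Ring AllReduce on N simulated ranks: a reduce-scatter phase then an
# all-gather phase, each taking N-1 steps. `sent` counts chunk-sends.
N = 4
data = [[float(r + 1)] * N for r in range(N)]   # rank r holds N chunks of r+1
sent = 0

# Reduce-scatter: in step s, rank r sends chunk (r - s) % N to rank (r+1) % N,
# which adds it into its own copy of that chunk.
for s in range(N - 1):
    outgoing = [(data[r][(r - s) % N], (r - s) % N) for r in range(N)]
    for r in range(N):
        val, c = outgoing[(r - 1) % N]
        data[r][c] += val
        sent += 1

total = float(sum(range(1, N + 1)))              # 1+2+3+4 = 10
# Rank r now owns the fully reduced chunk (r + 1) % N
assert all(data[r][(r + 1) % N] == total for r in range(N))

# All-gather: circulate the finished chunks around the same ring, overwriting.
for s in range(N - 1):
    outgoing = [(data[r][(r + 1 - s) % N], (r + 1 - s) % N) for r in range(N)]
    for r in range(N):
        val, c = outgoing[(r - 1) % N]
        data[r][c] = val
        sent += 1

assert all(row == [total] * N for row in data)
print(sent / N)   # 6.0 chunk-sends per rank = 2(N-1); vs N=4 chunks of data
```

Each rank sends 2(N−1) chunks out of N total, i.e. 2(N−1)/N of the data size — the bandwidth-optimal figure quoted above.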

Communication Path

GPU → NVLink → Network
Ring AllReduce on 8 GPUs:
  ↓ Reduce-Scatter phase
GPU0 → GPU1 → GPU2 → ... → GPU7 → GPU0
// Each GPU sends 1/8, receives 1/8 per step
  ↓ AllGather phase
GPU0 → GPU1 → GPU2 → ... → GPU7 → GPU0
// Each GPU broadcasts its chunk around the ring
338 KB docs • 4 files • Chapters 11+

Memory Hierarchy

HBM • L2 • Shared • Registers
"Why is memory bandwidth the bottleneck?"
H100 memory hierarchy:
Registers   256KB/SM   ∞ TB/s     # Fastest
Shared      228KB/SM   19 TB/s
L2 cache    50MB       12 TB/s
HBM3        80GB       3.35 TB/s  # Slowest

Memory Optimization

  • Coalesced access: 128B per warp
  • Avoid bank conflicts in shared memory
  • Tile to maximize L2 hits
  • Arithmetic intensity > 100 ops/byte
📚 Curriculum Files
11-memory-hierarchy-2026 (25KB)
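A quick roofline check using the numbers above shows why the arithmetic-intensity rule of thumb matters (bandwidth-only model; real kernels also pay latency):

```python
# Roofline model: a kernel is memory-bound until its arithmetic intensity
# crosses peak_flops / memory_bandwidth ("the ridge point").
PEAK_FLOPS = 989e12        # H100 FP16 tensor, dense
HBM_BW = 3.35e12           # HBM3 bytes/s

ridge = PEAK_FLOPS / HBM_BW
print(round(ridge))        # 295 FLOPs/byte to leave memory-bound territory

def achievable(intensity):           # intensity in FLOPs/byte
    return min(PEAK_FLOPS, intensity * HBM_BW)

# Elementwise add (~1/12 FLOP per byte): hopelessly memory-bound
print(achievable(1 / 12) / 1e12)     # ~0.28 TFLOPS
# Large GEMM (~1370 FLOPs/byte): compute-bound, hits peak
print(achievable(1370) / 1e12)       # 989.0 TFLOPS
```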

Data Movement

The real cost
GEMM arithmetic intensity:
Ops:  2 × M × N × K
Data: 2 × (M×K + K×N + M×N) bytes
  ↓ For 4096×4096×4096 FP16
Ops: 137B   Data: 100MB
Intensity: ~1370 ops/byte ✓
25 KB docs • 1 file • Chapter 11
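The quoted numbers check out directly (one FMA counted as 2 FLOPs; A and B read once, C written once, no cache reuse assumed):

```python
# Verify the arithmetic-intensity figures for a square 4096^3 FP16 GEMM.
M = N = K = 4096
BYTES = 2                                   # FP16

ops = 2 * M * N * K                         # multiply-add = 2 FLOPs
data = BYTES * (M * K + K * N + M * N)      # read A, B; write C

print(ops / 1e9)            # ~137.4 billion FLOPs ("137B")
print(data / 1e6)           # ~100.7 MB
print(round(ops / data))    # 1365 FLOPs/byte, the "~1370" quoted above
```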

Hardware Interconnect

NVLink • NVSwitch • XGMI
"How do GPUs physically communicate?"
DGX H100 (8 GPUs): NVSwitch fabric
├── GPU0 ←─ 900 GB/s ─→ NVSwitch
├── GPU1 ←─ 900 GB/s ─→ NVSwitch
├── ...
└── GPU7 ←─ 900 GB/s ─→ NVSwitch
Any-to-any at full bandwidth
Total: 7.2 TB/s fabric

Physical Layer

  • NVLink 4.0: 900 GB/s per GPU
  • NVSwitch: Full-mesh topology
  • GPUDirect RDMA: Network → GPU memory
  • InfiniBand: 400 Gb/s cross-node
📚 Curriculum Files
10-gpu-interconnect-2026 (26KB) • gpu-interconnect (122KB) • chiplet-mcm-architecture (48KB) • 18-gpu-state-of-art-2026 (26KB)
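A bandwidth-only back-of-envelope (ignoring latency and compute overlap, with a hypothetical 1 GB gradient buffer) shows why bandwidth-hungry collectives want to stay inside the NVLink domain:

```python
# Estimate ring-AllReduce time for 1 GB across 8 ranks on two link speeds.
SIZE = 1e9                     # bytes to reduce (assumed buffer size)
N = 8                          # ranks in the ring

def allreduce_time(link_bytes_per_s):
    traffic = 2 * (N - 1) / N * SIZE       # bytes each rank must send
    return traffic / link_bytes_per_s

print(allreduce_time(900e9) * 1e3)       # NVLink 900 GB/s: ~1.9 ms
print(allreduce_time(400e9 / 8) * 1e3)   # 400 Gb/s IB = 50 GB/s: ~35 ms
```

The ~18× gap between the two estimates mirrors the link-bandwidth ratio, which is why tensor parallelism is confined to the NVSwitch fabric while slower-cadence data parallelism crosses InfiniBand.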

Silicon

The physical foundation
NVIDIA H100 SXM:
Die:   814mm², 80B transistors
SMs:   132 streaming multiprocessors
Cores: 16,896 CUDA + 528 Tensor
HBM3:  80GB @ 3.35 TB/s
TDP:   700W
989 TFLOPS FP16 Tensor
222 KB docs • 4 files • Chapters 10, 18

The Complete Picture

📊 Layer Summary

  • ① PyTorch/Framework: entry point
  • ② Parallelism Strategy: 367 KB • 4 files
  • ③ Attention: 114 KB • 2 files
  • ④ BLAS: 165 KB • 2 files
  • ⑤ GPU Primitives: 261 KB • 6 items
  • ⑥ Software Stack: 338 KB • 4 files
  • ⑦ Memory: 25 KB • 1 file
  • ⑧ Hardware: 222 KB • 4 files

🔥 Key Takeaways

  • One FSDP call → all 8 layers
  • Each forward pass → AllGather + GEMM
  • Each backward pass → ReduceScatter
  • Memory is the bottleneck → Flash Attention
  • NVLink enables TP → 900 GB/s required

📚 Total Curriculum

  • Total files: 24 HTML + 1 folder
  • Total size: ~1.85 MB
  • Chapters: 10, 11, 12, 13, 14, 15, 16, 18
  • Animated files: 5
  • Reference docs: 9