🔷 FSDP COMPLETE CYCLE

FlatParameter Sharding → Forward → Backward → Update

The complete FSDP training cycle: how parameters flow through sharding, compute, and gradient synchronization

The seven stages:

1. Flatten: layer.weight [4096×4096] (16.7M params) and bias [4096] (4K params) are concatenated into a single FlatParameter, flat_param: [16,781,312].
2. Shard ÷4: the FlatParameter is chunked into four equal shards S₀ S₁ S₂ S₃, one per GPU.
3. AllGather: each GPU temporarily reconstructs the full parameter (S₀+S₁+S₂+S₃) and unflattens it back into weight and bias, ready for compute.
4. Forward: y = Wx + b produces the activations. As soon as the forward pass finishes, the full parameters are freed; only the local shard remains.
5. Backward: loss.backward() triggers a second AllGather (the full parameters are needed again to compute gradients), then computes ∇W = ∂L/∂W and ∇b = ∂L/∂b, yielding the full [16.8M] gradient.
6. ReduceScatter: gradients are summed (averaged) across GPUs and scattered so that each GPU keeps only its own shard ∇₀ ∇₁ ∇₂ ∇₃.
7. Update: optimizer.step() runs locally. Each GPU updates only its own shard with its own gradient shard, Sₖ = Sₖ - lr × ∇ₖ, with no communication. Ready for the next iteration!

Communication summary (per layer, per step): 2× AllGather (fwd + bwd) + 1× ReduceScatter (grads). Roughly 50% more communication than DDP, for N× less memory.
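The flatten-and-shard arithmetic in stages 1-2 can be verified with a short sketch (pure Python; the shapes are taken from the example layer above):

```python
# Flatten: concatenate weight [4096x4096] and bias [4096] into one flat buffer.
weight_numel = 4096 * 4096      # 16,777,216 elements
bias_numel = 4096
flat_numel = weight_numel + bias_numel
print(flat_numel)               # 16781312 -- matches flat_param: [16,781,312]

# Shard /4: each of the 4 GPUs owns one contiguous chunk.
world_size = 4
shard_numel = flat_numel // world_size
print(shard_numel)              # 4195328 elements per GPU
assert shard_numel * world_size == flat_numel   # divides evenly in this example
```

(Real FSDP pads the FlatParameter so it divides evenly by the world size; here it happens to divide exactly.)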
Click "Start Full Cycle" to animate the complete FSDP training loop
DDP Memory (Full Replication)
4Ψ per GPU (Ψ params + Ψ grads + 2Ψ optimizer states: Adam m, v)
FSDP Memory (Sharded)
4Ψ/N per GPU (N× reduction!)
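Under that accounting (Ψ params, Ψ grads, 2Ψ Adam states, all counted in elements), the per-GPU totals work out as:

```python
psi = 16_781_312            # parameter count of the example layer
n = 4                       # number of GPUs

ddp_elems = 4 * psi         # params + grads + Adam m + Adam v, fully replicated
fsdp_elems = 4 * psi // n   # the same four tensors, each sharded N ways
print(ddp_elems // fsdp_elems)   # 4 -- an N-fold reduction per GPU
```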

🔄 Why AllGather Twice?

Both the forward and backward passes need the full parameters. After forward, the gathered parameters are freed; before backward, AllGather runs again to rebuild them. This trades extra communication for memory.
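The gather-free-regather lifecycle can be sketched as a toy single-process simulation. Python lists stand in for tensors, and `all_gather` is a hypothetical stand-in for the real collective (e.g. `torch.distributed.all_gather`):

```python
world_size = 4
full = list(range(8))                    # pretend full FlatParameter
chunk = len(full) // world_size
shards = [full[i * chunk:(i + 1) * chunk] for i in range(world_size)]

def all_gather(shards):
    # Every rank reconstructs the full parameter from all shards.
    return [x for shard in shards for x in shard]

# Forward: gather, compute, then free the gathered copy.
params = all_gather(shards)
assert params == full
del params                               # after forward: full params freed

# Backward: gather AGAIN, because gradient computation needs full params.
params = all_gather(shards)
assert params == full                    # second AllGather, same result
```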

📉 ReduceScatter Magic

Fuses reduction and scattering into a single collective: gradients are summed element-wise across GPUs, then each GPU keeps only its own shard of the result. Per-GPU communication volume is (N-1)/N × size, the same as the reduce-scatter phase of a ring AllReduce.
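A minimal simulation of the semantics (sum across ranks, then each rank keeps 1/N of the result); `reduce_scatter` here is an illustrative helper, not the real `torch.distributed.reduce_scatter` API:

```python
world_size = 4
# Each rank computed a full gradient (8 elements) on its own data batch.
grads = [[r + 1] * 8 for r in range(world_size)]    # rank r's full gradient

def reduce_scatter(grads):
    # Sum element-wise across ranks, then give rank r only its chunk.
    # (Real FSDP also divides by world_size to average; omitted here.)
    summed = [sum(col) for col in zip(*grads)]
    chunk = len(summed) // len(grads)
    return [summed[r * chunk:(r + 1) * chunk] for r in range(len(grads))]

shard_grads = reduce_scatter(grads)
# Same result as AllReduce followed by slicing, but each rank only
# receives its 1/N of the data.
print(shard_grads[0])    # [10, 10]  (1+2+3+4 summed per element)
```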

🎯 Local Optimizer Step

Each GPU updates ONLY its shard using ONLY its gradient shard: no communication needed! Optimizer states (Adam's m and v) are likewise sharded, which is where the huge memory savings come from.
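The local update Sₖ = Sₖ - lr × ∇ₖ in miniature, for one rank (plain SGD for simplicity; the shard values are illustrative):

```python
lr = 0.1
param_shard = [1.0, 2.0]      # rank k's slice of flat_param
grad_shard = [10.0, 10.0]     # rank k's slice from ReduceScatter

# S_k = S_k - lr * grad_k: purely local, no collective involved.
param_shard = [p - lr * g for p, g in zip(param_shard, grad_shard)]
print(param_shard)            # [0.0, 1.0]
```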

⚡ The Trade-off

FSDP needs ~50% more communication than DDP: three collectives at (N-1)/N × size each, versus AllReduce's 2(N-1)/N. In exchange it uses N× less memory. With 8 GPUs, the 8× memory reduction lets you fit a roughly 8× larger model.
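A back-of-the-envelope check of the ~50% figure, assuming ring-based collectives (per-element volume: AllGather and ReduceScatter each move (N-1)/N, AllReduce moves 2(N-1)/N):

```python
n = 8                          # GPUs
ddp = 2 * (n - 1) / n          # one AllReduce per step
fsdp = 3 * (n - 1) / n         # 2x AllGather + 1x ReduceScatter per step
print(fsdp / ddp)              # 1.5 -> exactly 50% more communication
```

Note the ratio is 3/2 regardless of N; only the absolute per-element volumes depend on the GPU count.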