The complete FSDP training cycle: how parameters flow through sharding, compute, and gradient synchronization
Both forward AND backward need the full params. After forward, the gathered full params are freed; before backward, AllGather runs again to rebuild them. This trades extra communication for memory.
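The gather / compute / free / re-gather cycle can be sketched with a toy NumPy simulation. This is not the real FSDP API: `all_gather` here is an illustrative stand-in for the collective, and the "GPUs" are just list entries.

```python
import numpy as np

# Toy setup: 4 "GPUs", each owning a 1/N shard of a 16-element
# parameter vector (names here are illustrative, not torch APIs).
N = 4
full_params = np.arange(16.0)
shards = np.split(full_params, N)           # each GPU holds 16/N params

def all_gather(shards):
    """Every GPU reconstructs the full parameter tensor from all shards."""
    return np.concatenate(shards)

# Forward: gather full params, compute, then free the full copy.
params = all_gather(shards)
activations = params * 2.0                  # stand-in for the forward pass
params = None                               # full params freed after forward

# Backward: AllGather runs AGAIN so gradients can be computed.
params = all_gather(shards)
assert np.array_equal(params, full_params)  # same full weights as before
```

Only the persistent shards survive between phases; the full tensor exists just long enough for each compute step.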
ReduceScatter combines reduce + scatter in one op: it sums gradients across GPUs, then each GPU keeps only its shard of the sum. Communication = (N-1)/N × size per GPU.
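A minimal simulation of ReduceScatter, under the same toy assumptions (per-"GPU" gradients held as NumPy arrays; `reduce_scatter` is an illustrative name, not a real API):

```python
import numpy as np

# N "GPUs", each holding a FULL gradient tensor after backward.
N = 4
grads = [np.full(8, float(rank + 1)) for rank in range(N)]  # per-GPU grads

def reduce_scatter(grads, rank):
    summed = np.sum(grads, axis=0)          # the "reduce": sum across GPUs
    return np.split(summed, N)[rank]        # the "scatter": keep one shard

# GPU 0 ends up with only its 1/N slice of the summed gradient.
shard0 = reduce_scatter(grads, rank=0)
assert np.array_equal(shard0, np.full(2, 10.0))  # 1+2+3+4 = 10 per element

# Each GPU sends (N-1)/N of the tensor: its contribution to every
# other GPU's shard.
comm_fraction = (N - 1) / N                 # 0.75 of the tensor for N=4
```

The result matches an AllReduce followed by slicing, but the fused op never materializes the full reduced tensor on every GPU.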
Each GPU updates ONLY its param shard using ONLY its grad shard. No communication needed! Optimizer states (Adam's m, v) are sharded the same way = huge memory savings.
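A sketch of the local step, assuming a hand-rolled Adam update (the function name and hyperparameters are illustrative). The key point is that params, grads, and both moment buffers are all shard-sized, and nothing crosses GPUs:

```python
import numpy as np

# Adam hyperparameters (illustrative values).
lr, b1, b2, eps = 0.1, 0.9, 0.999, 1e-8

def adam_step_on_shard(p, g, m, v, t):
    """One Adam update on a single GPU's shard; purely local."""
    m = b1 * m + (1 - b1) * g               # sharded first moment
    v = b2 * v + (1 - b2) * g * g           # sharded second moment
    m_hat = m / (1 - b1 ** t)               # bias correction
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v

# GPU 0's shard only: local params, grad shard, and moment states.
p = np.ones(4)
g = np.full(4, 0.5)
m = np.zeros(4)
v = np.zeros(4)
p, m, v = adam_step_on_shard(p, g, m, v, t=1)
```

Since m and v each match the full parameter count in plain DDP, sharding them cuts optimizer-state memory by the same factor N as the params.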
FSDP: ~50% more communication than DDP, but N× less memory for params, grads, and optimizer states. For 8 GPUs: an 8× per-GPU memory reduction enables training roughly 8× larger models!
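The summary numbers can be checked with back-of-envelope arithmetic. Assumptions (mine, not from the text above): the common ~16 bytes/param accounting for mixed-precision Adam (fp16 param + fp16 grad + fp32 master param + fp32 m + fp32 v), and a 7B-parameter model as the example:

```python
# Memory: DDP replicates everything; FSDP shards it N ways.
bytes_per_param = 16                         # mixed-precision Adam accounting
n_params = 7e9                               # e.g. a 7B-parameter model
N = 8                                        # GPUs

ddp_gb = bytes_per_param * n_params / 1e9    # per-GPU: full replica
fsdp_gb = ddp_gb / N                         # per-GPU: 1/N of everything
assert ddp_gb == 112.0 and fsdp_gb == 14.0

# Communication per step, as a multiple of model size per GPU:
ddp_comm = 2 * (N - 1) / N                   # AllReduce ~ reduce + broadcast
fsdp_comm = 3 * (N - 1) / N                  # 2x AllGather + 1x ReduceScatter
assert fsdp_comm / ddp_comm == 1.5           # the "~50% more" above
```

The 3/2 ratio is where "~50% more communication" comes from: FSDP moves three (N-1)/N-sized volumes per step (gather for forward, gather for backward, reduce-scatter for grads) versus DDP's two.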