1. DeepSpeed Storage Deep Dive
🚨 Why DeepSpeed Matters
If you're training models over 10B parameters, you're almost certainly using DeepSpeed (or FSDP). DeepSpeed's ZeRO (Zero Redundancy Optimizer) stages fundamentally change how storage is accessed. Understanding this is essential.
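To see why the stages matter for storage, it helps to quantify what each stage sharding step saves. The sketch below is a back-of-the-envelope calculator (following the ZeRO paper's accounting for fp16 training with Adam), not DeepSpeed's internal bookkeeping:

```python
def zero_memory_per_gpu_gb(num_params, num_gpus, stage):
    """Approximate per-GPU memory (GB) for fp16 + Adam training.

    Bytes per parameter: 2 (fp16 weights) + 2 (fp16 grads)
    + 12 (fp32 master weights, momentum, variance).
    """
    P, N = num_params, num_gpus
    if stage == 0:    # plain data parallel: everything replicated
        total = P * (2 + 2 + 12)
    elif stage == 1:  # shard optimizer states
        total = P * (2 + 2 + 12 / N)
    elif stage == 2:  # also shard gradients
        total = P * (2 + (2 + 12) / N)
    elif stage == 3:  # also shard parameters
        total = P * (2 + 2 + 12) / N
    else:
        raise ValueError("stage must be 0-3")
    return total / 1e9

# 7B model on 64 GPUs: 112 GB replicated shrinks to 1.75 GB under ZeRO-3
for s in range(4):
    print(f"ZeRO-{s}: {zero_memory_per_gpu_gb(7e9, 64, s):.2f} GB/GPU")
```

Whatever no longer fits even after stage-3 sharding is what the NVMe offload config below has to absorb.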
ZeRO Stages and Storage Impact
```json
// ds_config_70b.json - production tested for 70B model training
// (comments shown for readability; strip them before use - DeepSpeed expects strict JSON)
{
  "train_batch_size": 2048,
  "train_micro_batch_size_per_gpu": 2,
  "gradient_accumulation_steps": 128,
  "zero_optimization": {
    "stage": 3,
    // Offload optimizer states to NVMe
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "/mnt/nvme_raid0",
      "buffer_count": 5,
      "buffer_size": 2147483648,  // 2GB buffers
      "fast_init": true,
      "pin_memory": true
    },
    // Offload parameters to NVMe (for very large models)
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/mnt/nvme_raid0",
      "buffer_count": 5,
      "buffer_size": 2147483648,
      "max_in_cpu": 10000000000  // Keep up to 10GB in CPU RAM first
    },
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_scatter": true
  },
  // Async I/O configuration - CRITICAL for NVMe offload performance.
  // Note: "aio" is a top-level key, not nested under zero_optimization.
  "aio": {
    "block_size": 1048576,     // 1MB blocks (NVMe optimal)
    "queue_depth": 32,         // Match SSD queue depth
    "thread_count": 8,         // Parallel I/O threads
    "single_submit": false,    // Batch submissions
    "overlap_events": true,    // Async overlap
    "use_gds": true            // Enable GPUDirect Storage!
  },
  "fp16": { "enabled": true }
}
```
Storage Requirements by Model Size
| Model Size | ZeRO Stage | NVMe Capacity | NVMe Bandwidth | Checkpoint |
|---|---|---|---|---|
| 7B (LLaMA) | ZeRO-2 | ~100 GB | 2-3 GB/s | ~14 GB |
| 13B | ZeRO-2/3 | ~200 GB | 3-5 GB/s | ~26 GB |
| 33B | ZeRO-3 | ~500 GB | 5-8 GB/s | ~66 GB |
| 65B/70B | ZeRO-3 + Offload | ~2 TB | 8-12 GB/s | ~140 GB |
| 175B (GPT-3) | ZeRO-Infinity | ~5 TB | 15+ GB/s | ~350 GB |
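The checkpoint column in the table follows roughly 2 bytes per parameter (fp16/bf16 weights only); a full training-state checkpoint that includes fp32 optimizer states is several times larger. A quick estimator (illustrative arithmetic, not framework code):

```python
def checkpoint_size_gb(num_params_billion, include_optimizer=False):
    """Estimate checkpoint size in GB.

    Weights only: 2 bytes/param (fp16/bf16).
    With Adam state: +12 bytes/param (fp32 master weights,
    momentum, variance). billions x bytes/param = GB directly.
    """
    bytes_per_param = 2 + (12 if include_optimizer else 0)
    return num_params_billion * bytes_per_param

print(checkpoint_size_gb(7))         # 14  -> matches the 7B row
print(checkpoint_size_gb(70))        # 140 -> matches the 65B/70B row
print(checkpoint_size_gb(70, True))  # 980 -> full resumable training state
```

This is why resumable checkpoints, not weights-only exports, should drive your capacity planning.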
DeepSpeed NVMe Filesystem Setup
```bash
# Optimal filesystem configuration for DeepSpeed NVMe offload

# 1. Create RAID0 across all NVMe drives (8x for an H100 node).
#    --chunk=1024 (KiB) matches the 1MiB stripe unit given to mkfs.xfs below.
sudo mdadm --create /dev/md0 --level=0 --raid-devices=8 --chunk=1024 \
    /dev/nvme0n1 /dev/nvme1n1 /dev/nvme2n1 /dev/nvme3n1 \
    /dev/nvme4n1 /dev/nvme5n1 /dev/nvme6n1 /dev/nvme7n1

# 2. Format with XFS (better for large files than ext4).
#    su = RAID chunk size, sw = number of data drives.
#    Log stripe unit is capped at 256KiB, so use -l su=256k (not 1m).
sudo mkfs.xfs -f -d su=1m,sw=8 -l su=256k /dev/md0

# 3. Mount with optimal options
sudo mkdir -p /mnt/nvme_raid0
sudo mount -o noatime,nodiratime,logbufs=8,logbsize=256k /dev/md0 /mnt/nvme_raid0

# 4. Verify performance (expect 50+ GB/s aggregate with 8 modern Gen4/Gen5 drives)
fio --name=seq_write --ioengine=libaio --direct=1 --rw=write \
    --bs=1m --numjobs=8 --size=10G --directory=/mnt/nvme_raid0

# 5. Pre-create the DeepSpeed offload directory
mkdir -p /mnt/nvme_raid0/deepspeed_offload
```
2. Megatron-LM Storage Architecture
⚡ Megatron Insight
Megatron-LM (NVIDIA's LLM framework) takes a different approach than DeepSpeed. It uses memory-mapped files (mmap) for training data, which leverages the kernel's page cache. This works well for sequential training access but requires careful data preprocessing.
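The effect of mmap-backed data can be sketched with numpy's memmap, a simplified stand-in for Megatron's indexed dataset format (the file layout here is a toy, not Megatron's actual format):

```python
import numpy as np

# Write a toy token file, then map it read-only: the kernel pages data
# in on demand and keeps hot pages cached across epochs.
# uint16 token ids fit vocabularies up to 65,536 entries.
tokens = np.arange(1_000_000, dtype=np.uint16)
tokens.tofile("/tmp/toy_tokens.bin")

data = np.memmap("/tmp/toy_tokens.bin", dtype=np.uint16, mode="r")

def get_sequence(idx, seq_len=2048):
    # Slicing a memmap touches only the pages it needs,
    # not the whole file - ideal for sequential epoch sweeps.
    return np.array(data[idx * seq_len:(idx + 1) * seq_len])

seq = get_sequence(3)
print(seq.shape, seq[0])  # (2048,) 6144
```

No explicit read calls, no user-space buffering: the page cache does the work, which is exactly why Megatron's sequential access pattern is so cheap.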
Megatron Data Pipeline
```bash
# Megatron-LM data preprocessing and training setup

# Step 1: Preprocess data into Megatron's indexed binary format
python tools/preprocess_data.py \
    --input /data/raw/corpus.json \
    --output-prefix /nvme/megatron/gpt2_train \
    --vocab-file /models/gpt2-vocab.json \
    --merge-file /models/gpt2-merges.txt \
    --dataset-impl mmap \
    --tokenizer-type GPT2BPETokenizer \
    --workers 96

# Step 2: Launch training with optimal I/O settings
python -m torch.distributed.launch \
    --nproc_per_node 8 --nnodes 4 \
    pretrain_gpt.py \
    --tensor-model-parallel-size 4 \
    --pipeline-model-parallel-size 2 \
    --num-layers 96 --hidden-size 12288 \
    --data-path /nvme/megatron/gpt2_train_text_document \
    --data-impl mmap \
    --save /checkpoints/gpt-70b \
    --async-save

# Step 3: Enable huge pages for better mmap performance
sudo sysctl -w vm.nr_hugepages=4096
```
Megatron vs DeepSpeed Comparison
Megatron-LM
- Data access: mmap (page cache)
- Best for: Sequential training data
- Checkpoint: Distributed per TP×PP rank
- NVMe role: Training data + checkpoints
- Memory: Uses system RAM cache
- Pro: Simple, well-tested at scale
- Con: Less flexible than DeepSpeed
DeepSpeed
- Data access: Direct I/O or GDS
- Best for: Memory-constrained training
- Checkpoint: Universal format, flexible
- NVMe role: Optimizer offload + checkpoints
- Memory: Explicit offload control
- Pro: Train larger models with less GPU
- Con: More complex configuration
3. PyTorch FSDP Storage Patterns
PyTorch's Fully Sharded Data Parallel (FSDP) is the native PyTorch answer to DeepSpeed ZeRO. Understanding its storage patterns is essential for PyTorch-native training pipelines.
FSDP Checkpoint Strategies
```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
    ShardedStateDictConfig,
)
from torch.distributed.checkpoint import FileSystemWriter, save


class FSDPCheckpointer:
    """Optimized FSDP checkpointing for large models."""

    def __init__(self, checkpoint_dir, model):
        self.checkpoint_dir = checkpoint_dir
        self.model = model

    def save_sharded(self, step, optimizer=None):
        """Fast sharded save - each rank saves its own shard."""
        with FSDP.state_dict_type(
            self.model,
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(offload_to_cpu=True),
        ):
            state_dict = {
                'model': self.model.state_dict(),
                'step': step,
            }
            if optimizer:
                state_dict['optimizer'] = FSDP.optim_state_dict(
                    self.model, optimizer
                )
            # Distributed save - each rank writes in parallel
            save(
                state_dict,
                checkpoint_id=f"{self.checkpoint_dir}/step_{step}",
                storage_writer=FileSystemWriter(
                    f"{self.checkpoint_dir}/step_{step}",
                    single_file_per_rank=True,  # Better I/O
                    sync_files=False,           # Skip fsync for speed
                    thread_count=4,             # Parallel writes
                ),
            )

    def save_consolidated(self, step):
        """Slow but portable - gather the full model on rank 0."""
        with FSDP.state_dict_type(
            self.model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            # state_dict() is a collective: EVERY rank must call it,
            # even though rank0_only means only rank 0 gets the result
            state_dict = self.model.state_dict()
        if torch.distributed.get_rank() == 0:
            torch.save(state_dict,
                       f"{self.checkpoint_dir}/consolidated_{step}.pt")
        torch.distributed.barrier()


# Usage pattern
checkpointer = FSDPCheckpointer("/mnt/nvme/checkpoints", model)

for step, batch in enumerate(dataloader):
    loss = train_step(model, batch)
    if step % 500 == 0:
        checkpointer.save_sharded(step, optimizer)  # Fast
    if step % 5000 == 0:
        checkpointer.save_consolidated(step)        # Portable
```
4. Distributed Training I/O Patterns
Data Parallel vs Model Parallel Storage
| Parallelism Type | Data Access Pattern | Storage Bottleneck | Optimization Strategy |
|---|---|---|---|
| Data Parallel (DDP) | Each GPU reads different data | Embarrassingly parallel | Pre-shard data per rank |
| Tensor Parallel (TP) | All TP ranks read same data | Checkpoint consolidation | Rank 0 reads, broadcast |
| Pipeline Parallel (PP) | Stage 0 reads, others compute | Stage 0 bottleneck | Double buffering in stage 0 |
| FSDP/ZeRO | All ranks read same batch | Checkpoint write amplification | Sharded checkpoints |
| Expert Parallel (MoE) | Random expert activation | Unpredictable access | Expert caching, prefetch |
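For the DDP row above, "pre-shard data per rank" means each rank is assigned a disjoint slice of the file list so no two GPUs ever read the same bytes. A minimal sketch of the assignment (file names are hypothetical):

```python
def shard_files(files, rank, world_size):
    """Deterministic per-rank shard: rank r takes files r, r+W, r+2W, ...

    Striding (rather than contiguous blocks) keeps shards balanced even
    when file sizes correlate with their position in the listing.
    """
    return files[rank::world_size]

files = [f"shard_{i:04d}.tar" for i in range(10)]
for rank in range(4):
    print(rank, shard_files(files, rank, 4))
# rank 0 gets shards 0, 4, 8; rank 1 gets 1, 5, 9; and so on
```

Because the assignment is a pure function of (rank, world_size), every worker computes its own slice with no coordination traffic.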
Checkpoint Frequency Guidelines
```python
def calculate_checkpoint_interval(
    model_size_gb,
    training_throughput_samples_per_sec,
    checkpoint_bandwidth_gb_per_sec,
    acceptable_rework_hours=0.5,
):
    """Calculate the optimal checkpoint interval based on:
    - time to save a checkpoint
    - cost of rework if a crash happens
    - training throughput
    """
    # Time to write one checkpoint
    checkpoint_time_seconds = model_size_gb / checkpoint_bandwidth_gb_per_sec

    # Samples processed in the acceptable rework window
    acceptable_rework_samples = (
        training_throughput_samples_per_sec * acceptable_rework_hours * 3600
    )

    # Ensure checkpoint overhead stays under 5% of training time
    min_interval_seconds = checkpoint_time_seconds * 20
    min_interval_samples = (
        min_interval_seconds * training_throughput_samples_per_sec
    )

    return max(acceptable_rework_samples, min_interval_samples)


# Example: 70B model training
interval = calculate_checkpoint_interval(
    model_size_gb=140,                   # 70B params x 2 bytes
    training_throughput_samples_per_sec=100,
    checkpoint_bandwidth_gb_per_sec=10,  # RAID0 NVMe
    acceptable_rework_hours=0.5,
)
print(f"Checkpoint every {interval:,.0f} samples")
# Output: Checkpoint every 180,000 samples (~30 min)
```
5. Framework-Specific Storage Patterns
Key Insight
Every ML framework does the same underlying operations: stream large sequential reads for training data, random reads for inference, and massive checkpoint writes. The storage doesn't care if it's PyTorch or TensorFlow. What matters is understanding when they access storage and how much they buffer.
PyTorch DataLoader with GDS
```python
import torch
from torch.utils.data import DataLoader, Dataset
import kvikio
import cupy as cp


class GPUDirectDataset(Dataset):
    """Dataset that uses GDS (via kvikio/cuFile) for direct GPU loading."""

    def __init__(self, file_paths, chunk_size=64 * 1024 * 1024):
        self.file_paths = file_paths
        self.chunk_size = chunk_size
        # Pre-allocate a small pool of GPU buffers to avoid per-item allocs
        self.buffer_pool = [cp.empty(chunk_size, dtype=cp.uint8)
                            for _ in range(4)]
        self.pool_idx = 0

    def __len__(self):
        # Required for map-style datasets used with DataLoader
        return len(self.file_paths)

    def __getitem__(self, idx):
        buf = self.buffer_pool[self.pool_idx % len(self.buffer_pool)]
        self.pool_idx += 1
        # cuFile read goes straight from NVMe into the GPU buffer
        with kvikio.CuFile(self.file_paths[idx], "r") as f:
            bytes_read = f.read(buf)
        return torch.as_tensor(buf[:bytes_read], device='cuda')


# Optimal DataLoader configuration for GDS
loader = DataLoader(
    GPUDirectDataset(file_list),
    batch_size=32,
    num_workers=0,     # CRITICAL: 0 workers for GDS
    pin_memory=False,  # Not needed - data is already on the GPU
)
```
⚠️ Critical PyTorch Pitfall
Using num_workers > 0 with GDS will silently fall back to CPU bounce buffers: each worker is a separate process that does not inherit the parent's GPU context.
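One way to make this pitfall fail loudly instead of silently is a small factory guard. This is a hypothetical helper, not part of kvikio or PyTorch:

```python
def make_gds_loader(dataset, batch_size=32, num_workers=0):
    """Build a DataLoader for a GDS-backed dataset, refusing settings
    that would silently disable GPUDirect Storage."""
    if num_workers != 0:
        raise ValueError(
            "GDS requires num_workers=0: DataLoader workers are separate "
            "processes without the parent's CUDA context, so cuFile reads "
            "silently fall back to CPU bounce buffers"
        )
    from torch.utils.data import DataLoader  # deferred: only needed here
    return DataLoader(dataset, batch_size=batch_size,
                      num_workers=0, pin_memory=False)
```

Routing all GDS loader construction through a guard like this turns a hard-to-diagnose performance regression into an immediate, explained error.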
TensorFlow tf.data Pipeline
```python
import tensorflow as tf


def create_optimized_dataset(file_pattern, batch_size=32):
    files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(
            x,
            compression_type='',             # No compression - let NVMe fly
            buffer_size=256 * 1024 * 1024,   # 256MB read buffer per file
            num_parallel_reads=4,            # Match NVMe queue depth
        ),
        cycle_length=16,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```
JAX/Flax Storage Patterns
```python
import jax
import orbax.checkpoint as ocp

# Configure the checkpointer for large model shards
options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=1000,
    # Async checkpointing overlaps I/O with training
    enable_async_checkpointing=True,
)

checkpoint_manager = ocp.CheckpointManager(
    '/nvme/checkpoints',  # Use local NVMe for speed
    options=options,
)

# Save a sharded checkpoint (each device saves its own shard)
checkpoint_manager.save(step, args=ocp.args.StandardSave(state))
```
💡 JAX Tip
JAX's functional approach means checkpoint data is typically larger (no in-place updates). Use orbax with async checkpointing to overlap I/O with computation.