C.2 • Framework Integration

ML Framework Integration

Deep integration patterns for DeepSpeed, Megatron-LM, PyTorch FSDP, and distributed training storage optimization.

1 • DeepSpeed Storage Deep Dive

🚨 Why DeepSpeed Matters: If you're training models over 10B parameters, you're almost certainly using DeepSpeed (or FSDP). DeepSpeed's ZeRO (Zero Redundancy Optimizer) stages fundamentally change how storage is accessed, so understanding them is essential.

ZeRO Stages and Storage Impact

Model training state (per GPU, unsharded):
  • Parameters (P):       ~4 bytes per param → 70B model: 280 GB
  • Gradients (G):        ~4 bytes per param → 70B model: 280 GB
  • Optimizer state (OS): ~12 bytes per param (Adam: m, v, fp32 master) → 70B model: 840 GB
  • Total: ~1.4 TB per GPU!

ZeRO stages (N = number of GPUs):
  • ZeRO-1: partition OS          → OS per GPU: 840/N GB
  • ZeRO-2: partition OS + G      → (OS+G) per GPU: 1120/N GB
  • ZeRO-3: partition OS + G + P  → total per GPU: 1400/N GB
  • ZeRO-Offload/Infinity: GPU memory exhausted → offload to NVMe
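The per-GPU arithmetic above can be sketched as a small helper (a hypothetical function, not part of DeepSpeed, using the same byte counts: 4 B/param for parameters and gradients, 12 B/param for Adam state):

```python
def zero_per_gpu_gb(params_billion, num_gpus, stage):
    """Estimate per-GPU training-state memory (GB) for each ZeRO stage."""
    p = params_billion * 4     # parameters: ~4 bytes each
    g = params_billion * 4     # gradients: ~4 bytes each
    os_ = params_billion * 12  # Adam m, v + fp32 master: ~12 bytes each
    if stage == 0:             # no sharding
        return p + g + os_
    if stage == 1:             # partition optimizer state only
        return p + g + os_ / num_gpus
    if stage == 2:             # partition optimizer state + gradients
        return p + (g + os_) / num_gpus
    if stage == 3:             # partition everything
        return (p + g + os_) / num_gpus
    raise ValueError("stage must be 0-3")

print(zero_per_gpu_gb(70, 64, 0))  # → 1400 (the "~1.4 TB per GPU")
print(zero_per_gpu_gb(70, 64, 3))  # → 21.875
```

With 64 GPUs, ZeRO-3 shrinks the 70B model's 1.4 TB of training state to under 22 GB per GPU; offload kicks in when even that exceeds what HBM can spare.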
JSON
// ds_config_70b.json - production-tested for 70B model training
// (comments are illustrative; strict JSON parsers require them removed)
{
    "train_batch_size": 2048,
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 128,
    
    "zero_optimization": {
        "stage": 3,
        
        // Offload optimizer states to NVMe
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "/mnt/nvme_raid0",
            "buffer_count": 5,
            "buffer_size": 2147483648,  // 2GB buffers
            "fast_init": true,
            "pin_memory": true
        },
        
        // Offload parameters to NVMe (for very large models)
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/mnt/nvme_raid0",
            "buffer_count": 5,
            "buffer_size": 2147483648,
            "max_in_cpu": 10000000000  // param elements kept in CPU before spilling to NVMe
        },
        
        // Async I/O configuration - CRITICAL for performance
        "aio": {
            "block_size": 1048576,     // 1MB blocks (NVMe optimal)
            "queue_depth": 32,         // Match SSD queue depth
            "thread_count": 8,         // Parallel I/O threads
            "single_submit": false,    // Batch submissions
            "overlap_events": true,    // Async overlap
            "use_gds": true            // Enable GPUDirect Storage!
        },
        
        "contiguous_gradients": true,
        "overlap_comm": true,
        "reduce_scatter": true
    },
    
    "fp16": {
        "enabled": true
    }
}

Storage Requirements by Model Size

Model Size     ZeRO Stage         NVMe Capacity   NVMe Bandwidth   Checkpoint
7B (LLaMA)     ZeRO-2             ~100 GB         2-3 GB/s         ~14 GB
13B            ZeRO-2/3           ~200 GB         3-5 GB/s         ~26 GB
33B            ZeRO-3             ~500 GB         5-8 GB/s         ~66 GB
65B/70B        ZeRO-3 + Offload   ~2 TB           8-12 GB/s        ~140 GB
175B (GPT-3)   ZeRO-Infinity      ~5 TB           15+ GB/s         ~350 GB
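The Checkpoint column follows a simple rule of thumb: roughly 2 bytes per parameter for fp16/bf16 weights, optimizer state excluded. As a hypothetical one-liner:

```python
def checkpoint_size_gb(params_billion, bytes_per_param=2):
    """Approximate weights-only checkpoint size (fp16/bf16 = 2 bytes/param)."""
    return params_billion * bytes_per_param

print(checkpoint_size_gb(7))    # → 14  (matches the 7B row)
print(checkpoint_size_gb(175))  # → 350 (matches the 175B row)
```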

DeepSpeed NVMe Filesystem Setup

Bash
# Optimal filesystem configuration for DeepSpeed NVMe offload

# 1. Create RAID0 across all NVMe drives (8x for H100 node)
sudo mdadm --create /dev/md0 --level=0 --raid-devices=8 \
    /dev/nvme0n1 /dev/nvme1n1 /dev/nvme2n1 /dev/nvme3n1 \
    /dev/nvme4n1 /dev/nvme5n1 /dev/nvme6n1 /dev/nvme7n1

# 2. Format with XFS (better for large files than ext4)
sudo mkfs.xfs -f -d su=1m,sw=8 -l su=1m /dev/md0

# 3. Mount with optimal options
sudo mkdir -p /mnt/nvme_raid0
sudo mount -o noatime,nodiratime,logbufs=8,logbsize=256k /dev/md0 /mnt/nvme_raid0

# 4. Verify performance (should be 50+ GB/s sequential write with 8 drives;
#    keep the queue deep or fio's default iodepth=1 will understate throughput)
fio --name=seq_write --ioengine=libaio --direct=1 --rw=write \
    --bs=1m --iodepth=32 --numjobs=8 --size=10G --group_reporting \
    --directory=/mnt/nvme_raid0

# 5. Pre-create DeepSpeed offload directory
mkdir -p /mnt/nvme_raid0/deepspeed_offload

2 • Megatron-LM Storage Architecture

⚡ Megatron Insight: Megatron-LM (NVIDIA's LLM framework) takes a different approach from DeepSpeed. It uses memory-mapped files (mmap) for training data, leveraging the kernel's page cache. This works well for sequential training access but requires careful data preprocessing.

Megatron Data Pipeline

Data preprocessing (offline):
  Raw text (corpus.json) → tokenization (GPT-2 BPE) → binary format (train.bin, mmap-able) + index file (train.idx)

Training data loading:
  train.bin on NVMe → mmap() → Linux page cache (DRAM) → GPU memory (HBM) as training batches

Sequential access → effective kernel prefetch → high throughput
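At its core, Megatron's indexed dataset is a memory-mapped token array (train.bin) plus an offset index (train.idx). The same access pattern can be sketched with NumPy (file name and dtype here are illustrative, not Megatron's exact on-disk format):

```python
import numpy as np

# Write a toy token file, then map it read-only. After the first pass, the
# Linux page cache serves repeated sequential reads from DRAM, and kernel
# readahead prefetches upcoming pages during sequential scans.
tokens = (np.arange(1_000_000) % 50_000).astype(np.uint16)  # fake token ids
tokens.tofile("/tmp/train_toy.bin")

mm = np.memmap("/tmp/train_toy.bin", dtype=np.uint16, mode="r")
seq_len = 2048
sample = mm[0:seq_len]  # pages fault in lazily; no per-sample read() syscall
print(sample.shape)     # → (2048,)
```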
Bash
# Megatron-LM data preprocessing and training setup

# Step 1: Preprocess data into Megatron format
python tools/preprocess_data.py \
    --input /data/raw/corpus.json \
    --output-prefix /nvme/megatron/gpt2_train \
    --vocab-file /models/gpt2-vocab.json \
    --merge-file /models/gpt2-merges.txt \
    --dataset-impl mmap \
    --tokenizer-type GPT2BPETokenizer \
    --workers 96

# Step 2: Launch training with optimal I/O settings
python -m torch.distributed.launch \
    --nproc_per_node 8 --nnodes 4 \
    pretrain_gpt.py \
    --tensor-model-parallel-size 4 \
    --pipeline-model-parallel-size 2 \
    --num-layers 96 --hidden-size 12288 \
    --data-path /nvme/megatron/gpt2_train_text_document \
    --data-impl mmap \
    --save /checkpoints/gpt-70b \
    --async-save

# Optional: reserve huge pages (can help some mmap-heavy workloads; the
# benefit for page-cache-backed file mmap varies by kernel and filesystem)
sudo sysctl -w vm.nr_hugepages=4096

Megatron vs DeepSpeed Comparison

Megatron-LM

  • Data access: mmap (page cache)
  • Best for: Sequential training data
  • Checkpoint: Distributed per TP×PP rank
  • NVMe role: Training data + checkpoints
  • Memory: Uses system RAM cache
  • Pro: Simple, well-tested at scale
  • Con: Less flexible than DeepSpeed

DeepSpeed

  • Data access: Direct I/O or GDS
  • Best for: Memory-constrained training
  • Checkpoint: Universal format, flexible
  • NVMe role: Optimizer offload + checkpoints
  • Memory: Explicit offload control
  • Pro: Train larger models with less GPU
  • Con: More complex configuration

3 • PyTorch FSDP Storage Patterns

PyTorch's Fully Sharded Data Parallel (FSDP) is the native PyTorch answer to DeepSpeed ZeRO. Understanding its storage patterns is essential for PyTorch-native training pipelines.

FSDP Checkpoint Strategies

Python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
    ShardedStateDictConfig,
)
from torch.distributed.checkpoint import FileSystemWriter, save

class FSDPCheckpointer:
    """Optimized FSDP checkpointing for large models"""
    
    def __init__(self, checkpoint_dir, model):
        self.checkpoint_dir = checkpoint_dir
        self.model = model
    
    def save_sharded(self, step, optimizer=None):
        """Fast sharded save - each rank saves its shard"""
        
        with FSDP.state_dict_type(
            self.model,
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(offload_to_cpu=True)
        ):
            state_dict = {
                'model': self.model.state_dict(),
                'step': step,
            }
            if optimizer:
                state_dict['optimizer'] = FSDP.optim_state_dict(
                    self.model, optimizer
                )
            
            # Distributed save - each rank writes in parallel
            save(
                state_dict,
                checkpoint_id=f"{self.checkpoint_dir}/step_{step}",
                storage_writer=FileSystemWriter(
                    f"{self.checkpoint_dir}/step_{step}",
                    single_file_per_rank=True,  # Better I/O
                    sync_files=False,  # Async write
                    thread_count=4,  # Parallel writes
                )
            )
    
    def save_consolidated(self, step):
        """Slow but portable - full model on rank 0"""
        
        with FSDP.state_dict_type(
            self.model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        ):
            if torch.distributed.get_rank() == 0:
                state_dict = self.model.state_dict()
                torch.save(state_dict, f"{self.checkpoint_dir}/consolidated_{step}.pt")
        
        torch.distributed.barrier()

# Usage pattern
checkpointer = FSDPCheckpointer("/mnt/nvme/checkpoints", model)

for step, batch in enumerate(dataloader):
    loss = train_step(model, batch)
    
    if step % 500 == 0:
        checkpointer.save_sharded(step, optimizer)  # Fast
    
    if step % 5000 == 0:
        checkpointer.save_consolidated(step)  # Portable

4 • Distributed Training I/O Patterns

Data Parallel vs Model Parallel Storage

Parallelism Type        Data Access Pattern             Storage Bottleneck               Optimization Strategy
Data Parallel (DDP)     Each GPU reads different data   Embarrassingly parallel          Pre-shard data per rank
Tensor Parallel (TP)    All TP ranks read same data     Checkpoint consolidation         Rank 0 reads, broadcast
Pipeline Parallel (PP)  Stage 0 reads, others compute   Stage 0 bottleneck               Double buffering in stage 0
FSDP/ZeRO               All ranks read same batch       Checkpoint write amplification   Sharded checkpoints
Expert Parallel (MoE)   Random expert activation        Unpredictable access             Expert caching, prefetch
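The "pre-shard data per rank" strategy in the DDP row can be as simple as a deterministic round-robin split of input files, so no two ranks ever read the same bytes (a minimal sketch; PyTorch's DistributedSampler does the equivalent at sample granularity):

```python
def shard_files(files, rank, world_size):
    """Give each data-parallel rank a disjoint, deterministic subset of files."""
    ordered = sorted(files)  # identical order on every rank, no coordination needed
    return [f for i, f in enumerate(ordered) if i % world_size == rank]

files = [f"shard_{i:04d}.parquet" for i in range(10)]
print(shard_files(files, rank=0, world_size=4))
# → ['shard_0000.parquet', 'shard_0004.parquet', 'shard_0008.parquet']
```

Because every rank derives its subset from the same sorted list, the split needs no communication, and the union of all ranks' subsets covers the dataset exactly once per epoch.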

Checkpoint Frequency Guidelines

Python
def calculate_checkpoint_interval(
    model_size_gb,
    training_throughput_samples_per_sec,
    checkpoint_bandwidth_gb_per_sec,
    acceptable_rework_hours=0.5
):
    """
    Calculate optimal checkpoint interval based on:
    - Time to save checkpoint
    - Cost of rework if crash happens
    - Training throughput
    """
    
    # Time to write checkpoint
    checkpoint_time_seconds = model_size_gb / checkpoint_bandwidth_gb_per_sec
    
    # Samples processed in acceptable rework time
    acceptable_rework_samples = (
        training_throughput_samples_per_sec * acceptable_rework_hours * 3600
    )
    
    # Ensure checkpoint overhead < 5% of training time
    min_interval_seconds = checkpoint_time_seconds * 20
    min_interval_samples = min_interval_seconds * training_throughput_samples_per_sec
    
    return max(acceptable_rework_samples, min_interval_samples)

# Example: 70B model training
interval = calculate_checkpoint_interval(
    model_size_gb=140,           # 70B params × 2 bytes
    training_throughput_samples_per_sec=100,
    checkpoint_bandwidth_gb_per_sec=10,  # RAID NVMe
    acceptable_rework_hours=0.5
)
print(f"Checkpoint every {interval:,.0f} samples")
# Output: Checkpoint every 180,000 samples (~30 min)

5 • Framework-Specific Storage Patterns

Key Insight: Every ML framework performs the same underlying storage operations: large sequential streaming reads for training data, random reads for inference, and massive checkpoint writes. The storage doesn't care whether it's PyTorch or TensorFlow; what matters is when each framework touches storage and how much it buffers.

PyTorch DataLoader with GDS

Python
import torch
from torch.utils.data import DataLoader, Dataset
import kvikio
import cupy as cp

class GPUDirectDataset(Dataset):
    """Dataset that uses GDS for direct GPU loading"""
    
    def __init__(self, file_paths, chunk_size=64*1024*1024):
        self.file_paths = file_paths
        self.chunk_size = chunk_size
        # Pre-allocate GPU buffer pool
        self.buffer_pool = [cp.empty(chunk_size, dtype=cp.uint8) 
                          for _ in range(4)]
        self.pool_idx = 0
    
    def __len__(self):
        # Required for map-style datasets with the default sampler
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        buf = self.buffer_pool[self.pool_idx % len(self.buffer_pool)]
        self.pool_idx += 1
        
        with kvikio.CuFile(self.file_paths[idx], "r") as f:
            bytes_read = f.read(buf)
        
        return torch.as_tensor(buf[:bytes_read], device='cuda')

# Optimal DataLoader configuration for GDS
loader = DataLoader(
    GPUDirectDataset(file_list),
    batch_size=32,
    num_workers=0,        # CRITICAL: 0 workers for GDS
    pin_memory=False,    # Not needed - already on GPU
)
⚠️ Critical PyTorch Pitfall: Using num_workers > 0 with GDS will silently fall back to CPU bounce buffers, because each worker is a new process that does not inherit the GPU context.

TensorFlow tf.data Pipeline

Python
import tensorflow as tf

def create_optimized_dataset(file_pattern, batch_size=32):
    files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(
            x,
            compression_type='',           # No compression - let NVMe fly
            buffer_size=256*1024*1024,  # 256MB buffer per file
            num_parallel_reads=4          # Match NVMe queue depth
        ),
        cycle_length=16,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

JAX/Flax Storage Patterns

Python
import jax
import orbax.checkpoint as ocp

# Configure checkpointer for large model shards
options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=1000,
    # Enable async checkpointing for overlap with training
    enable_async_checkpointing=True,
)

checkpoint_manager = ocp.CheckpointManager(
    '/nvme/checkpoints',  # Use local NVMe for speed
    options=options,
)

# Save sharded checkpoint (each device saves its shard)
checkpoint_manager.save(step, args=ocp.args.StandardSave(state))
💡 JAX Tip: JAX's functional approach means checkpoint data is typically larger (no in-place updates). Use Orbax with async checkpointing to overlap I/O with computation.