## Overview
The distributed training ecosystem in 2025 centres on a small number of frameworks, each with distinct strengths. Most production systems use a combination of these tools.
Why this matters for AGIACC: frameworks are where theory meets adoption. Any security company that wants to matter in AI infrastructure must understand the integration surface that teams actually deploy.
## PyTorch Distributed
PyTorch’s native distributed training stack has matured significantly:
### DistributedDataParallel (DDP)
The standard for single-program-multiple-data training. Handles gradient synchronisation transparently with NCCL or Gloo backends.
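A minimal single-process sketch of the setup (Gloo backend on CPU so it runs anywhere; the rendezvous address, port, and toy model are illustrative):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank: int = 0, world_size: int = 1) -> float:
    # Rendezvous settings; torchrun normally provides these.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(8, 2))  # hooks gradient all-reduce into backward
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    x, y = torch.randn(4, 8), torch.randn(4, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients are averaged across ranks in buckets
    opt.step()

    dist.destroy_process_group()
    return loss.item()
```

Multi-GPU jobs launch one process per device via `torchrun --nproc_per_node=N` and use the NCCL backend instead of Gloo.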
### Fully Sharded Data Parallel (FSDP)
Integrates ZeRO-3-style parameter sharding natively. Key features:
- Per-module sharding
- Mixed precision (fp16/bf16)
- Activation checkpointing
- CPU offloading
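FSDP's memory saving comes from the ZeRO-3 idea: each rank permanently stores only 1/N of the flattened parameters and all-gathers the full set just before use. A toy single-process sketch of that bookkeeping (plain tensors, not the FSDP API):

```python
import torch

def shard_params(flat: torch.Tensor, world_size: int) -> list[torch.Tensor]:
    """Split a flat parameter buffer into one shard per rank (ZeRO-3 style)."""
    return list(flat.chunk(world_size))

def gather_params(shards: list[torch.Tensor]) -> torch.Tensor:
    """Stand-in for the all-gather FSDP issues before each forward/backward."""
    return torch.cat(shards)

flat = torch.arange(12, dtype=torch.float32)  # pretend: flattened weights
shards = shard_params(flat, world_size=4)     # each rank holds 3 of 12 elements
assert all(s.numel() == 3 for s in shards)    # steady-state memory is P / N
assert torch.equal(gather_params(shards), flat)  # full weights exist only transiently
```

The full parameters are materialised per-module and freed immediately after use, which is why per-module sharding granularity matters.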
### DeviceMesh + DTensor
PyTorch 2.x introduced DeviceMesh for expressing multi-dimensional parallelism (TP × PP × DP) as a first-class concept, and DTensor for distributed tensor representations.
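The mapping DeviceMesh formalises can be sketched in plain Python: a flat list of ranks reshaped into named axes, so each rank gets a coordinate per parallelism dimension (a toy model of the concept, not the `torch.distributed.device_mesh` API):

```python
def mesh_coords(rank: int, mesh_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Map a flat rank to its coordinate along each mesh axis (row-major)."""
    coords = []
    for dim in reversed(mesh_shape):
        coords.append(rank % dim)
        rank //= dim
    return tuple(reversed(coords))

# A 2 x 4 mesh: axis 0 = data parallel, axis 1 = tensor parallel.
assert mesh_coords(5, (2, 4)) == (1, 1)   # DP group 1, TP position 1
assert mesh_coords(3, (2, 4)) == (0, 3)   # DP group 0, TP position 3
```

Collectives along one axis involve exactly the ranks that share all other coordinates, which is what lets TP, PP, and DP groups be composed without manual rank arithmetic.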
### torch.compile + Distributed
Graph compilation via torch.compile can now optimise across distributed boundaries, overlapping communication with computation for improved throughput.
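The entry point is a one-line wrap. The sketch below uses the debug-friendly `"eager"` backend so it runs without a C++ toolchain; real runs use the default Inductor backend and typically wrap the DDP/FSDP-wrapped module:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.GELU(),
    torch.nn.Linear(16, 2),
)
# backend="eager" traces the graph but skips codegen; the default backend
# ("inductor") is what generates fused kernels.
compiled = torch.compile(model, backend="eager")

x = torch.randn(4, 8)
assert torch.allclose(compiled(x), model(x), atol=1e-6)  # compilation must not change numerics
```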
## Microsoft DeepSpeed
DeepSpeed is among the most feature-rich distributed training libraries, built on top of PyTorch:
| Feature | Description |
|---|---|
| ZeRO (1/2/3) | Memory-efficient data parallelism |
| ZeRO-Offload / ZeRO-Infinity | Offload optimizer states and parameters to CPU or NVMe SSD |
| Pipeline Parallelism | Built-in 1F1B and interleaved scheduling |
| Sparse Attention | Efficient attention patterns for long sequences |
| MoE Support | Expert parallelism with DeepSpeed-MoE |
| Inference Optimisation | Tensor slicing, kernel injection, quantisation |
DeepSpeed’s low-friction integration (a single call to deepspeed.initialize() wraps the model, optimiser, and data loader) makes advanced features accessible to researchers without rewriting their training loops.
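A representative JSON config passed to deepspeed.initialize() alongside the model (the field names follow DeepSpeed's config schema; the specific values here are illustrative, not recommendations):

```json
{
  "train_batch_size": 256,
  "bf16": { "enabled": true },
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu" }
  }
}
```

The ZeRO settings map directly onto the table above: stage 3 shards parameters as well as gradients and optimiser states, and `offload_optimizer` moves the optimiser states to CPU (ZeRO-Offload).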
## NVIDIA Megatron-LM
Megatron-LM provides the canonical implementation of tensor parallelism and is the foundation of Megatron-DeepSpeed:
- Column/row-parallel linear layers
- Self-attention head parallelism
- Sequence parallelism for non-tensor-parallel regions
- Interleaved pipeline scheduling
- Extensive support for transformer architectures (GPT, BERT, T5)
Megatron-LM is highly optimised but tightly coupled to NVIDIA GPUs and NCCL. It is the backbone of NVIDIA’s NeMo framework.
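The column-parallel linear layer is easy to verify numerically: each rank multiplies by its column shard of the weight, and concatenating the partial outputs reproduces the unsharded result (a single-process sketch of the math, not Megatron's implementation):

```python
import torch

def column_parallel_linear(x: torch.Tensor, weight: torch.Tensor,
                           world_size: int) -> torch.Tensor:
    """Each 'rank' holds a column shard of `weight` and computes a slice of
    the output features; real TP replaces the final cat with an all-gather."""
    shards = weight.chunk(world_size, dim=1)   # split the output dimension
    partials = [x @ w for w in shards]         # one independent matmul per rank
    return torch.cat(partials, dim=1)

x = torch.randn(4, 16)
w = torch.randn(16, 32)
assert torch.allclose(column_parallel_linear(x, w, world_size=4), x @ w, atol=1e-5)
```

Row-parallel layers split the input dimension instead and finish with an all-reduce; pairing a column-parallel layer with a row-parallel one is what lets an MLP block get away with a single all-reduce in each of the forward and backward passes.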
## Ray Train
Ray Train (part of the Ray ecosystem) provides a higher-level abstraction for distributed training:
- Framework-agnostic — wraps PyTorch, TensorFlow, JAX, and Hugging Face Transformers
- Elastic scaling — dynamically add/remove workers during training
- Fault tolerance — automatic checkpoint/restart on worker failure
- Heterogeneous resources — mix GPU types and generations within a job
- Integration with Ray Tune for hyperparameter search
Ray Train is particularly popular for teams that need to orchestrate training across heterogeneous cloud infrastructure.
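The fault-tolerance pattern Ray Train automates (checkpoint every step, restart surviving work from the last checkpoint) can be sketched framework-free; this is toy control flow, not the Ray API:

```python
def train_with_restarts(total_steps: int, fail_at: set[int]) -> tuple[int, int]:
    """Run `total_steps` of 'training', simulating a worker crash at each step
    in `fail_at` and resuming from the last checkpoint after every crash."""
    checkpoint = 0   # last step whose state was durably saved
    restarts = 0
    pending_failures = set(fail_at)
    while checkpoint < total_steps:
        step = checkpoint + 1
        if step in pending_failures:
            pending_failures.discard(step)  # crash once, then the retry succeeds
            restarts += 1
            continue                        # restart: resume from `checkpoint`
        checkpoint = step                   # step completed; checkpoint advances
    return checkpoint, restarts

done, restarts = train_with_restarts(10, fail_at={3, 7})
assert (done, restarts) == (10, 2)  # all steps finish despite two crashes
```

Ray Train runs this loop for you across a worker pool, detecting failures and re-scheduling workers, so the user-supplied training function only has to save and load checkpoints.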
## JAX / XLA
Google’s JAX takes a fundamentally different approach: programs are written as pure functions, which are compiled and parallelised by the XLA compiler:
- pmap / pjit — express parallelism declaratively via sharding annotations (pjit has since been unified into jax.jit)
- The compiler automatically partitions computation and inserts communication primitives
- Native support for TPU pods (v4, v5e, v6e)
- Used to train Gemini, PaLM, and other Google frontier models
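A minimal sharded-execution sketch (runs on any device count, including a single CPU; the axis name "data" is arbitrary). The caller only annotates how inputs are laid out; XLA partitions the computation and inserts whatever communication that layout requires:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; on a CPU-only machine this is a single device.
mesh = Mesh(np.array(jax.devices()), ("data",))

@jax.jit  # pjit functionality now lives in jax.jit
def layer(x, w):
    return jnp.tanh(x @ w)

# Shard the batch dimension of x across the "data" axis; w stays replicated.
# Assumes the batch size (8) is divisible by the device count.
x = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P("data", None)))
w = jnp.full((4, 2), 0.5)
y = layer(x, w)  # XLA handles any cross-device communication
assert y.shape == (8, 2)
```

On a TPU pod the identical program runs unchanged with a larger mesh, which is the core of the compiler-driven approach.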
## Comparison
| Framework | TP | PP | DP/ZeRO | MoE | Best For |
|---|---|---|---|---|---|
| PyTorch DDP/FSDP | DeviceMesh | Basic | ✓ | Community | General-purpose, PyTorch-native |
| DeepSpeed | Via Megatron | ✓ | ✓ (ZeRO 1-3) | ✓ | Feature-rich, memory-efficient |
| Megatron-LM | ✓ (canonical) | ✓ | ✓ | ✓ | NVIDIA GPU clusters, frontier scale |
| Ray Train | Via backends | Via backends | ✓ | Via backends | Multi-cloud, elastic, heterogeneous |
| JAX/XLA | Auto (pjit) | Auto | Auto | ✓ | TPU, compiler-driven parallelism |
Next: Security Challenges →