pytorch-distributed
Total installs: 1
Weekly installs: 1
Site rank: #50200
Install command:
npx skills add https://github.com/cuba6112/skillfactory --skill pytorch-distributed
Agent install distribution:
- mcpjam: 1
- claude-code: 1
- junie: 1
- windsurf: 1
- zencoder: 1
- crush: 1
Skill Documentation
Overview
PyTorch Distributed enables training models across multiple GPUs and nodes. DistributedDataParallel (DDP) is the standard for multi-process data parallelism, while Fully Sharded Data Parallel (FSDP) shards model state to allow training models too large for a single GPU.
When to Use
Use DDP for general multi-GPU training on a single or multiple nodes. Use FSDP when model parameters, gradients, and optimizer states exceed the memory of a single GPU.
Decision Tree
- Does your model fit on one GPU?
  - YES: Use DistributedDataParallel (DDP).
  - NO: Use Fully Sharded Data Parallel (FSDP); a minimal FSDP sketch follows this list.
- Are you launching the job?
  - USE: torchrun to handle environment setup and fault recovery.
- Are you saving a checkpoint?
  - DO: Only save on rank == 0 to avoid file corruption and redundant I/O.
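If the model does not fit on one GPU, a minimal FSDP sketch might look like the following. It assumes a torchrun launch (which sets RANK/LOCAL_RANK/WORLD_SIZE) and uses a toy nn.Sequential as a stand-in for a genuinely large model:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    # torchrun is assumed to provide RANK/LOCAL_RANK/WORLD_SIZE in the environment.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Toy stand-in for a model that would normally exceed one GPU's memory.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 10)).cuda()
    model = FSDP(model)  # shards parameters, gradients, and optimizer state across ranks

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```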
Workflows
- Setting Up a DDP Training Job (sketch below)
  - Initialize the process group using dist.init_process_group() with an appropriate backend (e.g., 'nccl').
  - Set the current device for the process using torch.cuda.set_device(rank).
  - Wrap the model with DistributedDataParallel.
  - Wrap the dataset with a DistributedSampler to ensure unique data shards per process.
  - Clean up the process group using dist.destroy_process_group() after training.
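A hedged sketch of this workflow, assuming a torchrun launch; the tiny nn.Linear model and synthetic TensorDataset are purely illustrative:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def main():
    # 1. Initialize the process group (torchrun supplies rank/world size via env vars).
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    # 2. One GPU per process.
    torch.cuda.set_device(local_rank)

    # 3. Wrap the model with DDP.
    model = nn.Linear(20, 5).cuda()
    ddp_model = DDP(model, device_ids=[local_rank])

    # 4. DistributedSampler gives each rank a unique shard of the data.
    dataset = TensorDataset(torch.randn(1024, 20), torch.randint(0, 5, (1024,)))
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(x.cuda()), y.cuda())
            loss.backward()  # gradients are all-reduced across ranks here
            optimizer.step()

    # 5. Clean up the process group.
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```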
- Checkpointing in Distributed Environments (sketch below)
  - Check if the current process is rank 0 (dist.get_rank() == 0).
  - Only rank 0 saves the model state dict to disk.
  - Call dist.barrier() to ensure all other processes wait until the file is written.
  - All processes load the checkpoint using torch.load(..., map_location=...).
  - Resume training or perform evaluation.
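A sketch of the rank-0 save / barrier / mapped-load pattern. CHECKPOINT_PATH is a hypothetical path, and the ddp_model and local_rank arguments are assumed to come from your own training script (e.g., the DDP setup above):

```python
import torch
import torch.distributed as dist

CHECKPOINT_PATH = "model.checkpoint"  # hypothetical path; adjust for your job

def save_and_reload(ddp_model, local_rank):
    """Save on rank 0, then have every rank load the same file safely."""
    # Only rank 0 writes the state dict, avoiding redundant I/O and partial writes.
    if dist.get_rank() == 0:
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Every process waits here until rank 0 has finished writing.
    dist.barrier()

    # Map tensors saved from cuda:0 onto this rank's own device to avoid OOM on GPU 0.
    map_location = {"cuda:0": f"cuda:{local_rank}"}
    state_dict = torch.load(CHECKPOINT_PATH, map_location=map_location)
    ddp_model.load_state_dict(state_dict)
```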
- Launching with torchrun (sketch below)
  - Refactor training code to read LOCAL_RANK and RANK from environment variables.
  - Remove manual mp.spawn() logic and use dist.init_process_group(backend='nccl') without rank/world_size args.
  - Execute the script via torchrun --nproc_per_node=G script.py. torchrun handles process spawning, master address setup, and fault recovery.
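A minimal torchrun-compatible entry point, assuming a single-node launch; the --nproc_per_node=4 value in the comment is only an example for G:

```python
# Launch (single node, 4 GPUs as an example):  torchrun --nproc_per_node=4 script.py
import os
import torch
import torch.distributed as dist

def main():
    # torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE for every worker it spawns.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # No rank/world_size arguments needed; they are read from the environment.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    print(f"rank {rank}/{world_size} running on cuda:{local_rank}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```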
Non-Obvious Insights
- Multi-Process vs Multi-Thread: DDP is multi-process, whereas DataParallel is single-process multi-threaded. DDP is significantly faster because it avoids Python's Global Interpreter Lock (GIL) contention.
- Mapping Locations: The map_location argument in torch.load is mandatory in DDP to prevent multiple processes from attempting to load tensors into the same GPU (usually rank 0), which would cause an Out of Memory (OOM) error.
- Synchronization Points: In DDP, the constructor, forward pass, and backward pass act as distributed synchronization points where processes communicate gradients.
Evidence
- “GPU devices cannot be shared across DDP processes (i.e. one GPU for one DDP process).” (https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- “In DDP, the constructor, the forward pass, and the backward pass are distributed synchronization points.” (https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)
Scripts
- scripts/pytorch-distributed_tool.py: Boilerplate for a torchrun-compatible DDP script.
- scripts/pytorch-distributed_tool.js: Node.js wrapper to launch torchrun commands.
Dependencies
- torch
- nccl (for GPU communication)
- gloo (for CPU-based distributed testing)
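For CPU-only smoke tests with the gloo backend, a hedged sketch such as the following can verify that ranks communicate; the cpu_test.py file name and two-process launch are illustrative only:

```python
# Launch example: torchrun --nproc_per_node=2 cpu_test.py
import os
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")  # gloo works without GPUs
    rank = int(os.environ["RANK"])

    # Simple all-reduce to confirm the ranks can talk to each other.
    t = torch.tensor([float(rank)])
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: sum of ranks = {t.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```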