pytorch-distributed

📁 cuba6112/skillfactory 📅 4 days ago
Total installs: 1
Weekly installs: 1
Site rank: #50200
Install Command
npx skills add https://github.com/cuba6112/skillfactory --skill pytorch-distributed

Agent Install Distribution

mcpjam 1
claude-code 1
junie 1
windsurf 1
zencoder 1
crush 1

Skill Documentation

Overview

PyTorch Distributed enables training models across multiple GPUs and nodes. DistributedDataParallel (DDP) is the standard for multi-process data parallelism, while Fully Sharded Data Parallel (FSDP) shards model state to allow training models too large for a single GPU.

When to Use

Use DDP for general multi-GPU training on a single node or across multiple nodes. Use FSDP when model parameters, gradients, and optimizer states exceed the memory of a single GPU.

Decision Tree

  1. Does your model fit on one GPU?
    • YES: Use DistributedDataParallel (DDP).
    • NO: Use Fully Sharded Data Parallel (FSDP); a minimal wrapping sketch follows this list.
  2. Are you launching the job?
    • USE: torchrun to handle environment setup and fault recovery.
  3. Are you saving a checkpoint?
    • DO: Only save on rank == 0 to avoid file corruption and redundant I/O.
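
When the model does not fit on one GPU (the FSDP branch above), wrapping it is a small change once the process group exists. The following is a minimal sketch assuming a torchrun launch; the nn.Sequential stand-in and its layer sizes are illustrative placeholders, not part of this skill's bundled scripts.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")  # torchrun supplies rank/world size via env vars
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Stand-in for a model too large for one GPU.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# FSDP shards parameters, gradients, and optimizer state across ranks.
model = FSDP(model, device_id=local_rank)

# Build the optimizer after wrapping so it references the sharded parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

dist.destroy_process_group()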

Workflows

  1. Setting Up a DDP Training Job

    1. Initialize the process group using dist.init_process_group() with an appropriate backend (e.g., 'nccl' for GPUs).
    2. Set the current device for the process using torch.cuda.set_device(local_rank) so that each process drives its own GPU.
    3. Wrap the model with DistributedDataParallel.
    4. Pass a DistributedSampler to the DataLoader so each process receives a unique data shard, and call sampler.set_epoch(epoch) every epoch so shuffling varies across epochs.
    5. Clean up the process group using dist.destroy_process_group() after training. (A runnable skeleton combining these steps appears after this list.)
  2. Checkpointing in Distributed Environments

    1. Check if the current process is rank 0 (dist.get_rank() == 0).
    2. Only rank 0 saves the model state dict to disk (for a DDP-wrapped model, save model.module.state_dict() so the keys match an unwrapped model).
    3. Call dist.barrier() to ensure all other processes wait until the file is written.
    4. All processes load the checkpoint using torch.load(..., map_location=...).
    5. Resume training or perform evaluation. (A checkpoint sketch appears after this list.)
  3. Launching with torchrun

    1. Refactor training code to read LOCAL_RANK and RANK from environment variables.
    2. Remove manual mp.spawn() logic and use dist.init_process_group(backend='nccl') without rank/world_size args.
    3. Execute the script via torchrun --nproc_per_node=G script.py, where G is the number of GPUs on the node.
    4. torchrun handles process spawning, master address and port setup, and fault recovery. (A minimal launch example appears after this list.)
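
The five DDP setup steps combine into a short torchrun-compatible skeleton. This is a minimal sketch rather than the skill's bundled script: the Linear model, the synthetic TensorDataset, and the hyperparameters are placeholders.

import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def main():
    # Step 1: torchrun supplies RANK, WORLD_SIZE, and MASTER_ADDR via env vars.
    dist.init_process_group(backend="nccl")

    # Step 2: bind this process to its own GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Step 3: wrap the model.
    model = torch.nn.Linear(10, 1).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # Step 4: DistributedSampler gives each process a unique data shard.
    dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle differently each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            loss = F.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()  # gradients are all-reduced across ranks here
            optimizer.step()

    # Step 5: clean up.
    dist.destroy_process_group()

if __name__ == "__main__":
    main()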
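
Continuing the skeleton above (model is DDP-wrapped and local_rank is already set), the rank-0 save, barrier, and remapped load from the checkpointing workflow look like the following sketch; model.pt is an illustrative path.

import torch
import torch.distributed as dist

ckpt_path = "model.pt"  # illustrative path

# Steps 1-2: only rank 0 writes; .module unwraps DDP so keys match a plain model.
if dist.get_rank() == 0:
    torch.save(model.module.state_dict(), ckpt_path)

# Step 3: block every rank until the file is fully written.
dist.barrier()

# Step 4: each rank remaps tensors saved from cuda:0 onto its own device.
map_location = {"cuda:0": f"cuda:{local_rank}"}
state = torch.load(ckpt_path, map_location=map_location)
model.module.load_state_dict(state)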
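
For the torchrun workflow itself, the script only needs to read the environment that torchrun populates. A minimal sketch, where 4 processes per node is an example value:

# Launch with: torchrun --nproc_per_node=4 script.py
import os
import torch.distributed as dist

# torchrun sets MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, and WORLD_SIZE,
# so init_process_group needs no explicit rank/world_size arguments.
dist.init_process_group(backend="nccl")

rank = int(os.environ["RANK"])              # global rank across all nodes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

print(f"rank {rank}/{world_size} (local_rank {local_rank}) initialized")
dist.destroy_process_group()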

Non-Obvious Insights

  • Multi-Process vs Multi-Thread: DDP is multi-process, whereas DataParallel is single-process multi-threaded. DDP is significantly faster because it avoids Python’s Global Interpreter Lock (GIL) contention.
  • Mapping Locations: Always pass the map_location argument to torch.load in DDP; without it, every process deserializes tensors onto the device they were saved from (usually GPU 0), which can cause an out-of-memory (OOM) error.
  • Synchronization Points: In DDP, the constructor, the forward pass, and the backward pass act as distributed synchronization points: the constructor broadcasts model state from rank 0, and the backward pass all-reduces gradients across processes.

Evidence

Scripts

  • scripts/pytorch-distributed_tool.py: Boilerplate for a torchrun-compatible DDP script.
  • scripts/pytorch-distributed_tool.js: Node.js wrapper to launch torchrun commands.

Dependencies

  • torch
  • nccl (for GPU communication)
  • gloo (for CPU-based distributed testing)
