[Misc] Allow enabling NCCL for DP sync when async scheduling (#32197)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-01-12 18:03:08 -08:00
committed by GitHub
parent 78d13ea9de
commit 9273a427b5
4 changed files with 26 additions and 16 deletions

View File

@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import Field, model_validator
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self
@@ -182,9 +183,12 @@ class ParallelConfig:
threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch."""
disable_nccl_for_dp_synchronization: bool = False
disable_nccl_for_dp_synchronization: bool = Field(default=None)
"""Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
to use Gloo instead of NCCL for its all reduce"""
to use Gloo instead of NCCL for its all reduce.
Defaults to True when async scheduling is enabled, False otherwise.
"""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
@@ -292,6 +296,12 @@ class ParallelConfig:
should only be set by API server scale-out.
"""
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
return None if value is None else handler(value)
@model_validator(mode="after")
def _validate_parallel_config(self) -> Self:
if self._api_process_rank >= self._api_process_count: