diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 2e1ca74ed..363078aef 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal import torch -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -182,9 +183,12 @@ class ParallelConfig: threshold, microbatching will be used. Otherwise, the request will be processed in a single batch.""" - disable_nccl_for_dp_synchronization: bool = False + disable_nccl_for_dp_synchronization: bool = Field(default=None) """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py - to use Gloo instead of NCCL for its all reduce""" + to use Gloo instead of NCCL for its all reduce. + + Defaults to True when async scheduling is enabled, False otherwise. + """ ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -292,6 +296,12 @@ class ParallelConfig: should only be set by API server scale-out. """ + @field_validator("disable_nccl_for_dp_synchronization", mode="wrap") + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + """Skip validation if the value is `None` when initialisation is delayed.""" + return None if value is None else handler(value) + @model_validator(mode="after") def _validate_parallel_config(self) -> Self: if self._api_process_rank >= self._api_process_count: diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 781d13c69..1bceaa933 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -209,9 +209,7 @@ class SchedulerConfig: @classmethod def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: """Skip validation if the value is `None` when initialisation is delayed.""" - if value is None: - return value - return handler(value) + return None if value is None else handler(value) def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None: if is_encoder_decoder: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f9318c8c6..88c2e100a 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -629,20 +629,22 @@ class VllmConfig: else: self.scheduler_config.async_scheduling = True - if ( - self.scheduler_config.async_scheduling - and not self.parallel_config.disable_nccl_for_dp_synchronization - ): - logger.info_once( - "Disabling NCCL for DP synchronization when using async scheduling." - ) - self.parallel_config.disable_nccl_for_dp_synchronization = True - logger.info_once( "Asynchronous scheduling is %s.", "enabled" if self.scheduler_config.async_scheduling else "disabled", ) + if self.parallel_config.disable_nccl_for_dp_synchronization is None: + if self.scheduler_config.async_scheduling: + logger.info_once( + "Disabling NCCL for DP synchronization " + "when using async scheduling.", + scope="local", + ) + self.parallel_config.disable_nccl_for_dp_synchronization = True + else: + self.parallel_config.disable_nccl_for_dp_synchronization = False + from vllm.platforms import current_platform if ( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7631cd61d..b7f3969ee 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -413,7 +413,7 @@ class EngineArgs: ubatch_size: int = ParallelConfig.ubatch_size dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold - disable_nccl_for_dp_synchronization: bool = ( + disable_nccl_for_dp_synchronization: bool | None = ( ParallelConfig.disable_nccl_for_dp_synchronization ) eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")