diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index bcc689070..751a29795 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -9,6 +9,7 @@ from vllm.config import (
     ECTransferConfig,
     KVTransferConfig,
     ModelConfig,
+    ParallelConfig,
     SchedulerConfig,
     SpeculativeConfig,
     VllmConfig,
@@ -53,6 +54,7 @@ def create_scheduler(
     num_speculative_tokens: int | None = None,
     skip_tokenizer_init: bool = False,
     async_scheduling: bool = False,
+    pipeline_parallel_size: int = 1,
     use_ec_connector: bool = False,
     ec_role: str | None = None,
 ) -> Scheduler | AsyncScheduler:
@@ -133,6 +135,7 @@ def create_scheduler(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
+        parallel_config=ParallelConfig(pipeline_parallel_size=pipeline_parallel_size),
         kv_transfer_config=kv_transfer_config,
         speculative_config=speculative_config,
         ec_transfer_config=ec_transfer_config,
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index a84acd8e6..6f36cbf9b 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -563,11 +563,6 @@ class VllmConfig:
 
         if self.scheduler_config.async_scheduling:
             # Async scheduling explicitly enabled, hard fail any incompatibilities.
-            if self.parallel_config.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Async scheduling is not yet compatible with "
-                    "pipeline_parallel_size > 1."
-                )
             # Currently, async scheduling only support eagle speculative
             # decoding.
             if self.speculative_config is not None:
@@ -589,14 +584,7 @@ class VllmConfig:
                 )
         elif self.scheduler_config.async_scheduling is None:
             # Enable async scheduling unless there is an incompatible option.
-            if self.parallel_config.pipeline_parallel_size > 1:
-                logger.warning_once(
-                    "Async scheduling is not yet supported with "
-                    "pipeline_parallel_size > 1 and will be disabled.",
-                    scope="local",
-                )
-                self.scheduler_config.async_scheduling = False
-            elif (
+            if (
                 self.speculative_config is not None
                 and self.speculative_config.method not in get_args(EagleModelTypes)
             ):
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index bdd6d2a3c..1660d5189 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -283,6 +283,13 @@ class Scheduler(SchedulerInterface):
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
 
+            # do not schedule another step for the same request while it still has
+            # output placeholders for PP.
+            # TODO: support PP + async scheduling without this limit
+            if self.use_pp and request.num_output_placeholders > 0:
+                req_index += 1
+                continue
+
             if (
                 request.num_output_placeholders > 0
                 # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 7b427b4a6..4735035d7 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -411,9 +411,9 @@ class MultiprocExecutor(Executor):
 
     @cached_property
     def max_concurrent_batches(self) -> int:
-        if self.scheduler_config.async_scheduling:
-            return 2
-        return self.parallel_config.pipeline_parallel_size
+        # PP requires PP-size concurrent batches to fill the pipeline.
+        pp_size = self.parallel_config.pipeline_parallel_size
+        return 2 if pp_size <= 1 and self.scheduler_config.async_scheduling else pp_size
 
     def _get_output_rank(self) -> int:
         # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py
index 292fa877f..c8c6185b6 100644
--- a/vllm/v1/executor/ray_executor.py
+++ b/vllm/v1/executor/ray_executor.py
@@ -111,9 +111,8 @@ class RayDistributedExecutor(Executor):
         """Ray distributed executor supports pipeline parallelism, meaning
         that it allows PP size batches to be executed concurrently.
         """
-        if self.scheduler_config.async_scheduling:
-            return 2
-        return self.parallel_config.pipeline_parallel_size
+        pp_size = self.parallel_config.pipeline_parallel_size
+        return 2 if pp_size <= 1 and self.scheduler_config.async_scheduling else pp_size
 
     def shutdown(self) -> None:
         if logger: