Implement Async Scheduling (#19970)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-07-14 23:01:46 -07:00
committed by GitHub
parent 85bd6599e4
commit d4d309409f
11 changed files with 508 additions and 148 deletions

View File

@@ -484,6 +484,8 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
async_scheduling: bool = SchedulerConfig.async_scheduling
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@@ -921,6 +923,8 @@ class EngineArgs:
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
scheduler_group.add_argument("--async-scheduling",
**scheduler_kwargs["async_scheduling"])
# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
@@ -1206,6 +1210,26 @@ class EngineArgs:
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port
if self.async_scheduling:
# Async scheduling does not work with the uniprocess backend.
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "mp"
logger.info("Using mp-based distributed executor backend "
"for async scheduling.")
if self.distributed_executor_backend == "uni":
raise ValueError("Async scheduling is not supported with "
"uni-process backend.")
if self.pipeline_parallel_size > 1:
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")
# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if self.speculative_config is not None:
raise ValueError(
"Currently, speculative decoding is not supported with "
"async scheduling.")
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@@ -1286,6 +1310,7 @@ class EngineArgs:
long_prefill_token_threshold=self.long_prefill_token_threshold,
disable_hybrid_kv_cache_manager=self.
disable_hybrid_kv_cache_manager,
async_scheduling=self.async_scheduling,
)
if not model_config.is_multimodal_model and self.default_mm_loras: