[Bugfix] Propagate compilation_time from workers to main process for TP>1 (#35503)

Signed-off-by: Huy Do <huydhn@gmail.com>
This commit is contained in:
Huy Do
2026-02-27 21:03:22 -08:00
committed by GitHub
parent dea268336f
commit 7b346ba8ed
4 changed files with 20 additions and 5 deletions

View File

@@ -115,7 +115,15 @@ class Executor(ABC):
underlying workers.
"""
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
self.collective_rpc("compile_or_warm_up_model")
compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
# Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they
# compile in parallel.
if compilation_times:
self.vllm_config.compilation_config.compilation_time = max(
compilation_times
)
def register_failure_callback(self, callback: FailureCallback): # noqa: B027
"""

View File

@@ -118,11 +118,12 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
return self.compilation_config.compilation_time
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]

View File

@@ -480,7 +480,7 @@ class Worker(WorkerBase):
self.model_runner.initialize_kv_cache(kv_cache_config)
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
@@ -605,6 +605,8 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
return self.compilation_config.compilation_time
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()

View File

@@ -87,8 +87,12 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
def compile_or_warm_up_model(self) -> float:
"""Prepare model for execution through compilation/warmup.
Returns:
The accumulated compilation time in seconds.
"""
raise NotImplementedError
def check_health(self) -> None: