[V1] AsyncLLM data parallel (#13923)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-27 16:14:41 -07:00
committed by GitHub
parent 112b3e5b3b
commit 15dac210f0
18 changed files with 722 additions and 156 deletions

View File

@@ -66,11 +66,17 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers: list[StatLoggerBase] = []
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats:
if logger.isEnabledFor(logging.INFO):
self.stat_loggers.append(LoggingStatLogger())
self.stat_loggers.append(PrometheusStatLogger(vllm_config))
for i in range(vllm_config.parallel_config.data_parallel_size):
loggers: list[StatLoggerBase] = []
if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
@@ -329,6 +335,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
@@ -350,12 +357,13 @@ class AsyncLLM(EngineClient):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_index: int = 0,
):
if not self.log_stats:
return
assert scheduler_stats is not None
for stat_logger in self.stat_loggers:
for stat_logger in self.stat_loggers[engine_index]:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
@@ -393,8 +401,9 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None,
model_output=None,
) -> None:
for stat_logger in self.stat_loggers:
stat_logger.log()
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None:
logger.debug("Called check_health.")