[P/D] Dynamic kv_output_aggregator collect size (#26734)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-10-22 18:07:58 +02:00
committed by GitHub
parent 58fab50d82
commit 4dfdb821c8
7 changed files with 90 additions and 19 deletions

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from concurrent.futures import Future
from functools import cached_property
from typing import Literal, TypeVar, overload
from typing import TYPE_CHECKING, Literal, TypeVar, overload
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
@@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
logger = init_logger(__name__)
_R = TypeVar("_R")
@@ -233,10 +236,10 @@ class Executor(ABC):
"""Shutdown the executor."""
self.collective_rpc("shutdown")
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
"""Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator(
finished_count or self.parallel_config.world_size
self.kv_output_aggregator = KVOutputAggregator.from_connector(
connector, self.parallel_config.world_size
)
@cached_property # Avoid unnecessary RPC calls