[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)

Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
Author: Chenguang Zheng
Date: 2025-11-12 10:58:33 +08:00
Committed by: GitHub
Parent: cbb799e314
Commit: 4ccffe561f
31 changed files with 5026 additions and 42 deletions


@@ -9,6 +9,7 @@ from vllm.config.compilation import (
PassConfig,
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
@@ -54,6 +55,8 @@ __all__ = [
"PassConfig",
# From vllm.config.device
"DeviceConfig",
# From vllm.config.ec_transfer
"ECTransferConfig",
# From vllm.config.kv_events
"KVEventsConfig",
# From vllm.config.kv_transfer

vllm/config/ec_transfer.py Normal file

@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
ECProducer = Literal["ec_producer"]
ECConsumer = Literal["ec_consumer"]
ECRole = Literal[ECProducer, ECConsumer]
@config
@dataclass
class ECTransferConfig:
"""Configuration for distributed EC cache transfer."""
ec_connector: str | None = None
"""The EC connector for vLLM to transmit EC caches between vLLM instances.
"""
engine_id: str | None = None
"""The engine id for EC transfers."""
ec_buffer_device: str | None = "cuda"
"""The device used by ec connector to buffer the EC cache.
Currently only support 'cuda'."""
ec_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
ec_role: ECRole | None = None
"""Whether this vLLM instance produces, consumes EC cache, or both. Choices
are 'ec_producer', 'ec_consumer'."""
ec_rank: int | None = None
"""The rank of this vLLM instance in the EC cache transfer. Typical value:
0 for the encoder instance, 1 for the PD instance.
Currently only 1P1D is supported."""
ec_parallel_size: int = 1
"""The number of parallel instances for EC cache transfer. For
PyNcclConnector, this should be 2."""
ec_ip: str = "127.0.0.1"
"""The EC connector ip, used to build distributed connection."""
ec_port: int = 14579
"""The EC connector port, used to build distributed connection."""
ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
ec_connector_module_path: str | None = None
"""The Python module path to dynamically load the EC connector from.
Only supported in V1."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.ec_role is not None and self.ec_role not in get_args(ECRole):
raise ValueError(
f"Unsupported ec_role: {self.ec_role}. "
f"Supported roles are {get_args(ECRole)}"
)
if self.ec_connector is not None and self.ec_role is None:
raise ValueError(
"Please specify ec_role when ec_connector "
f"is set, supported roles are {get_args(ECRole)}"
)
@property
def is_ec_transfer_instance(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECRole)
@property
def is_ec_producer(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECProducer)
@property
def is_ec_consumer(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.ec_connector_extra_config.get(key, default)
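
As a quick orientation for reviewers, a minimal usage sketch of the config above (assuming this patch is applied; the storage path is an example value, and `shared_storage_path` is the key read by the shared-storage connector later in this diff):

from vllm.config.ec_transfer import ECTransferConfig

producer_cfg = ECTransferConfig(
    ec_connector="ECSharedStorageConnector",
    ec_role="ec_producer",
    ec_connector_extra_config={"shared_storage_path": "/tmp/ec_cache"},
)
# Role/identity helpers are derived from ec_connector + ec_role.
assert producer_cfg.is_ec_transfer_instance
assert producer_cfg.is_ec_producer and not producer_cfg.is_ec_consumer
# engine_id is auto-generated in __post_init__ when not provided.
assert producer_cfg.engine_id is not None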


@@ -28,6 +28,7 @@ from vllm.utils import random_uuid
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .ec_transfer import ECTransferConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
@@ -103,6 +104,8 @@ class VllmConfig:
"""The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None
"""The configurations for event publishing."""
ec_transfer_config: ECTransferConfig | None = None
"""The configurations for distributed EC cache transfer."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@@ -183,6 +186,10 @@ class VllmConfig:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.ec_transfer_config:
vllm_factors.append(self.ec_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(


@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.ec_transfer.ec_transfer_state import (
ensure_ec_transfer_initialized,
get_ec_transfer,
has_ec_transfer,
)
__all__ = [
"get_ec_transfer",
"ensure_ec_transfer_initialized",
"has_ec_transfer",
]


@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ECConnectorBase Class for Distributed Encoder Cache &
P2P Encoder cache communication in V1
The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save the encoder cache.
has_caches() - check whether an encoder cache exists for each
mm input of a request
update_state_after_alloc() - update ECConnector state after
allocation; this decides whether the cache will be loaded
request_finished() - called when a request is finished; frees
the cache associated with the request
Worker-side: runs in each worker, loads/saves the encoder cache
to/from the connector based on the metadata.
start_load_caches() - starts loading all encoder caches (maybe async)
save_caches() - saves the encoder cache for a given mm_hash
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""
import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
class ECConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
# Connector running in the worker process
WORKER = 1
class ECConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler ECConnector and Worker ECConnector.
"""
pass
class ECConnectorBase(ABC):
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
self._connector_metadata: ECConnectorMetadata | None = None
self._vllm_config = vllm_config
self._role = role
if vllm_config.ec_transfer_config is not None:
self._is_producer = vllm_config.ec_transfer_config.is_ec_producer
else:
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
@property
def role(self) -> ECConnectorRole:
return self._role
@property
def is_producer(self) -> bool:
return self._is_producer
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
EC cache loading.
Args:
connector_metadata (ECConnectorMetadata): the connector metadata.
"""
self._connector_metadata = connector_metadata
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = None
def _get_connector_metadata(self) -> ECConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_caches(
self,
ec_caches: dict[str, torch.Tensor],
):
"""
Initialize with the EC caches.
Args:
ec_caches: dictionary of encoder cache
"""
# TODO: Implement this later for P2P feature
return
@abstractmethod
def start_load_caches(
self, encoder_cache: dict[str, torch.Tensor], **kwargs
) -> None:
"""
Start loading the cache from the connector into vLLM's encoder cache.
This method loads the encoder cache based on metadata provided by the scheduler.
It is called before `_gather_mm_embeddings` in the model runner; any
additional connector-specific values are passed through `kwargs`.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
"""
pass
@abstractmethod
def save_caches(
self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
) -> None:
"""
Save the encoder cache to the connector.
This method saves the encoder cache from the worker's local storage
to shared storage or another external connector.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
mm_hash (str): The hash of the multimodal data whose cache is being saved.
kwargs (dict): Additional keyword arguments for the connector.
"""
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies the worker-side connector of the ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
A tuple of (sending/saving ids, recving/loading ids) for requests
that have finished asynchronous transfer (requests that previously
returned True from request_finished()).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
# ==============================
# Scheduler-side methods
# ==============================
@abstractmethod
def has_caches(
self,
request: "Request",
) -> list[bool]:
"""
Check if an encoder cache exists for each mm input of the request.
Args:
request (Request): the request object.
Returns:
A list of bools where the i-th value is True if a cache exists
for the i-th mm input of the request.
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request", index: int):
"""
Update ECConnector state after the encoder cache is allocated
for the given mm input of the request.
Args:
request (Request): the request object.
index (int): index of the mm input within the request.
"""
pass
@abstractmethod
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> ECConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
def update_connector_output(self, connector_output: ECConnectorOutput):
"""
Update ECConnector state from worker-side connectors output.
Args:
connector_output (ECConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished(
self, request: "Request"
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its encoder cache is freed.
Returns:
True if the request is being saved/sent asynchronously and its cache
should not be freed until the request_id is returned from
get_finished(), plus an optional dict of connector-specific metadata.
"""
return False, None
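
For reference, a minimal no-op subclass sketch covering just the five abstract methods above (the class names here are hypothetical, not part of this PR):

from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorMetadata,
)

class _NoOpECConnectorMetadata(ECConnectorMetadata):
    pass

class _NoOpECConnector(ECConnectorBase):
    # Worker-side: nothing to load or save.
    def start_load_caches(self, encoder_cache, **kwargs) -> None:
        return

    def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
        return

    # Scheduler-side: report a miss for every mm input so the
    # encoder always runs locally.
    def has_caches(self, request) -> list[bool]:
        return [False] * len(request.mm_features)

    def update_state_after_alloc(self, request, index: int) -> None:
        return

    def build_connector_meta(self, scheduler_output) -> ECConnectorMetadata:
        return _NoOpECConnectorMetadata()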


@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
# yapf: disable
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import ECTransferConfig, VllmConfig
logger = init_logger(__name__)
class ECConnectorFactory:
_registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[ECConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(
cls,
config: "VllmConfig",
role: ECConnectorRole,
) -> ECConnectorBase:
ec_transfer_config = config.ec_transfer_config
if ec_transfer_config is None:
raise ValueError("ec_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(ec_transfer_config)
logger.info(
"Creating connector with name: %s and engine_id: %s",
connector_cls.__name__,
ec_transfer_config.engine_id,
)
# Connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
# - Should only be used inside the Scheduler class
# Worker connector:
# - Co-locate with worker process
return connector_cls(config, role)
@classmethod
def get_connector_class(
cls, ec_transfer_config: "ECTransferConfig"
) -> type[ECConnectorBase]:
"""Get the connector class by name."""
connector_name = ec_transfer_config.ec_connector
if connector_name is None:
raise ValueError("EC connect must not be None")
elif connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = ec_transfer_config.ec_connector_module_path
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
ECConnectorFactory.register_connector(
"ECSharedStorageConnector",
"vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
"ECSharedStorageConnector",
)
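
Out-of-tree connectors can be resolved in two ways, per the factory logic above; a hedged sketch (module path and class name are hypothetical):

from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory

# Option 1: explicit registration with lazy import.
ECConnectorFactory.register_connector(
    "MyECConnector", "my_pkg.ec_connectors", "MyECConnector"
)

# Option 2: no registration; set ec_connector_module_path in
# ECTransferConfig, and get_connector_class() imports that module and
# looks up the class named by ec_connector.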


@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors.torch
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class MMMeta:
mm_hash: str
num_token: int
@staticmethod
def make_meta(mm_hash, num_token) -> "MMMeta":
return MMMeta(mm_hash=mm_hash, num_token=num_token)
@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta]
def __init__(self):
self.mm_datas = []
def add_mm_data(self, mm_data: MMMeta):
self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase):
# NOTE: This is a simple debug implementation of the EC connector.
# It saves / loads the EC cache to / from disk.
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
# req_id -> index
self._mm_datas_need_loads: dict[str, int] = {}
transfer_config = vllm_config.ec_transfer_config
if transfer_config is not None:
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
logger.debug(transfer_config)
logger.debug("Shared storage path is %s", self._storage_path)
else:
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
def start_load_caches(self, encoder_cache, **kwargs) -> None:
"""
Start loading the cache from the connector into vLLM's encoder cache.
This method loads the encoder cache based on metadata provided by the scheduler.
It is called before `_gather_mm_embeddings` in the model runner; any
additional connector-specific values are passed through `kwargs`.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
"""
# Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert encoder_cache is not None
if metadata is None:
logger.warning(
"In connector.start_load_caches, "
"but the connector metadata is None"
)
return
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
# Load the EC for each mm data
for mm_data in metadata.mm_datas:
if mm_data.mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
"""
Save the encoder cache to the connector.
This method saves the encoder cache from the worker's local storage
to shared storage or another external connector.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
mm_hash (str): The hash of the multimodal data whose cache is being saved.
kwargs (dict): Additional keyword arguments for the connector.
"""
# Return early if this is a PD (consumer) instance.
if not self.is_producer:
return
filename = self._generate_filename_debug(mm_hash)
ec_cache = encoder_cache[mm_hash]
tensors = {"ec_cache": ec_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
logger.debug("Save cache successful for mm_hash %s", mm_hash)
def has_caches(
self,
request: "Request",
) -> list[bool]:
"""
Check if a cache exists externally for each mm input of the request.
Args:
request (Request): the request object.
Returns:
A list of bools indicating whether the i-th mm input exists in the cache.
"""
result = []
for feature in request.mm_features:
result.append(self._found_match_for_mm_data(feature.identifier))
return result
def update_state_after_alloc(
self,
request: "Request",
index: int,
) -> None:
"""
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
# Record the mm_hash and its encoder token count for loading.
self._mm_datas_need_loads[mm_hash] = num_encoder_token
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> ECConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
This only builds metadata for the mm inputs that need to be loaded.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ECSharedStorageConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_mm_data(self, mm_hash) -> bool:
"""Check if the cache is hit for the request."""
filename = self._generate_filename_debug(mm_hash)
return os.path.exists(filename)
def _generate_foldername_debug(
self,
mm_hash: str,
create_folder: bool = True,
) -> str:
"""
Return the folder in which the cache for this mm_hash lives.
If `create_folder` is True (default) the directory is created
recursively the first time it is needed.
"""
foldername = os.path.join(self._storage_path, mm_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(self, mm_hash: str) -> str:
"""
Return the full path of the safetensors file for this mm_hash.
Ensures the parent directory exists because
`_generate_foldername_debug` is called with its default
(`create_folder=True`).
"""
foldername = self._generate_foldername_debug(mm_hash) # <- folder auto-created
return os.path.join(foldername, "encoder_cache.safetensors")
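
The on-disk layout this debug connector produces is one folder per mm_hash containing a single safetensors file keyed "ec_cache". A standalone round-trip sketch (paths and tensor shape are example values):

import os

import safetensors.torch
import torch

storage_path, mm_hash = "/tmp", "0123abcd"
folder = os.path.join(storage_path, mm_hash)
os.makedirs(folder, exist_ok=True)
filename = os.path.join(folder, "encoder_cache.safetensors")

saved = torch.randn(16, 64)
safetensors.torch.save_file({"ec_cache": saved.detach().cpu()}, filename)

loaded = safetensors.torch.load_file(filename)["ec_cache"]
assert torch.equal(saved, loaded)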


@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm import envs
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
if TYPE_CHECKING:
from vllm.config import VllmConfig
_EC_CONNECTOR_AGENT: ECConnectorBase | None = None
def get_ec_transfer() -> ECConnectorBase:
assert _EC_CONNECTOR_AGENT is not None, "disaggregated EC cache is not initialized"
return _EC_CONNECTOR_AGENT
def has_ec_transfer() -> bool:
return _EC_CONNECTOR_AGENT is not None
def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize EC cache connector.
"""
global _EC_CONNECTOR_AGENT
if vllm_config.ec_transfer_config is None:
return
if (
vllm_config.ec_transfer_config.is_ec_transfer_instance
and _EC_CONNECTOR_AGENT is None
):
if envs.VLLM_USE_V1:
_EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector(
config=vllm_config, role=ECConnectorRole.WORKER
)
else:
raise ValueError("V0 is no longer supported")


@@ -38,6 +38,7 @@ from vllm.config import (
CompilationConfig,
ConfigType,
DeviceConfig,
ECTransferConfig,
EPLBConfig,
KVEventsConfig,
KVTransferConfig,
@@ -527,6 +528,8 @@ class EngineArgs:
kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None
ec_transfer_config: ECTransferConfig | None = None
generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = get_field(
@@ -1105,6 +1108,9 @@ class EngineArgs:
"--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
)
vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
vllm_group.add_argument(
"--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
)
vllm_group.add_argument(
"--compilation-config", "-O", **vllm_kwargs["compilation_config"]
)
@@ -1676,6 +1682,7 @@ class EngineArgs:
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
additional_config=self.additional_config,
)
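
A hedged sketch of wiring the new flag programmatically (the model name is an example; the equivalent CLI form is --ec-transfer-config with a JSON payload of the same fields):

from vllm.config import ECTransferConfig
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct",
    ec_transfer_config=ECTransferConfig(
        ec_connector="ECSharedStorageConnector",
        ec_role="ec_consumer",
    ),
)
# ec_transfer_config is threaded into VllmConfig as shown above.
vllm_config = args.create_engine_config()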


@@ -49,10 +49,18 @@ def kernel_warmup(worker: "Worker"):
except NotImplementedError:
return False
if not worker.model_runner.is_pooling_model and all(
_is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups
for group in groups
# NOTE: we check for empty attn_groups to avoid errors when deploying
# encoder-only models or E instances in EPD disaggregation, since
# worker.model_runner.attn_groups is empty for those models.
# This check was added during EPD feature development.
if (
not worker.model_runner.is_pooling_model
and worker.model_runner.attn_groups
and all(
_is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups
for group in groups
)
):
logger.info("Warming up FlashInfer attention.")
# Warmup with mixed batch containing both prefill and decode tokens


@@ -14,6 +14,7 @@ if TYPE_CHECKING:
import numpy.typing as npt
import torch
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
else:
ECConnectorMetadata = object
KVConnectorMetadata = object
LoRARequest = object
MultiModalFeatureSpec = object
@@ -188,6 +190,9 @@ class SchedulerOutput:
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
@dataclass
class GrammarOutput:


@@ -7,6 +7,11 @@ from collections.abc import Iterable
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
self.kv_events_config,
self.parallel_config.data_parallel_rank,
)
self.ec_connector = None
if self.vllm_config.ec_transfer_config is not None:
self.ec_connector = ECConnectorFactory.create_connector(
config=self.vllm_config, role=ECConnectorRole.SCHEDULER
)
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0
@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
external_load_encoder_input: list[int] = []
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
request.num_computed_tokens,
@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
num_computed_tokens,
@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Allocate for external load encoder cache
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface):
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface):
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
meta: KVConnectorMetadata = self.connector.build_connector_meta(
scheduler_output
)
scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector
if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
scheduler_output
)
scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens: int,
num_new_tokens: int,
encoder_compute_budget: int,
) -> tuple[list[int], int, int]:
) -> tuple[list[int], int, int, list[int]]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface):
in this step, i.e.,
[num_computed_tokens, num_computed_tokens + num_new_tokens).
- It is not already computed and stored in the encoder cache.
- It does not already exist in the remote encoder cache (via ECConnector).
- There is sufficient encoder token budget to process it.
- The encoder cache has space to store it.
@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface):
blocks and externally cached blocks (via KVConnector).
"""
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_compute_budget
return [], num_new_tokens, encoder_compute_budget, []
encoder_inputs_to_schedule: list[int] = []
mm_features = request.mm_features
assert mm_features is not None
assert len(mm_features) > 0
external_load_encoder_input = []
# Check remote cache first
if self.ec_connector is not None:
remote_cache_has_item = self.ec_connector.has_caches(request)
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule,
num_new_tokens,
encoder_compute_budget,
external_load_encoder_input,
)
def get_grammar_bitmask(


@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
@@ -136,6 +138,13 @@ class KVConnectorOutput:
)
@dataclass
class ECConnectorOutput:
# [mm_hash]
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
@@ -167,6 +176,8 @@ class ModelRunnerOutput:
kv_connector_output: KVConnectorOutput | None = None
ec_connector_output: ECConnectorOutput | None = None
# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None
@@ -192,6 +203,41 @@ class DraftTokenIds:
draft_token_ids: list[list[int]]
def make_empty_encoder_model_runner_output(
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
"""
Create a ModelRunnerOutput stub that contains the correct
per-request bookkeeping but no generated data yet.
"""
if not scheduler_output.num_scheduled_tokens:
return EMPTY_MODEL_RUNNER_OUTPUT
# Convert to list so we get a deterministic, indexable sequence
req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())
# Give every request its own contiguous index
req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
# No real tokens sampled yet ⇒ one placeholder token id per request
sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
# Pooler outputs are not available yet ⇒ use None placeholders
pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
kv_connector_output=None,
ec_connector_output=None,
num_nans_in_logits=None,
)
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},


@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define EC connector functionality mixin for model runners.
"""
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
import torch
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a EC connector functionality mixin for ModelRunner (GPU, TPU)
class ECConnectorModelRunnerMixin:
@staticmethod
def maybe_save_ec_to_connector(
encoder_cache: dict[str, torch.Tensor],
mm_hash: str,
):
if not has_ec_transfer():
logger.debug("Not have ec transfer please check")
return
connector = get_ec_transfer()
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
@staticmethod
def get_finished_ec_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
if has_ec_transfer():
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
return None, None
@staticmethod
def maybe_get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> AbstractContextManager[ECConnectorOutput | None]:
return (
ECConnectorModelRunnerMixin._get_ec_connector_output(
scheduler_output, encoder_cache, **kwargs
)
if has_ec_transfer()
else nullcontext()
)
# This context manager must be used within an active forward context.
# It encapsulates the entire EC connector lifecycle within execute_model.
@staticmethod
@contextmanager
def _get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> Generator[ECConnectorOutput, None, None]:
output = ECConnectorOutput()
ec_connector = get_ec_transfer()
assert isinstance(ec_connector, ECConnectorBase)
assert scheduler_output.ec_connector_metadata is not None
ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)
if not ec_connector.is_producer:
ec_connector.start_load_caches(encoder_cache, **kwargs)
try:
yield output
finally:
output.finished_sending, output.finished_recving = (
ec_connector.get_finished(scheduler_output.finished_req_ids)
)
ec_connector.clear_connector_metadata()
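
A sketch of how a model runner is expected to drive this mixin (compare the GPU model runner diff below; `runner` stands in for any class mixing this in, and `_execute_mm_encoder` is the GPU runner's encoder entry point):

def run_encoder_step(runner, scheduler_output):
    # Yields None when no EC transfer is configured (nullcontext);
    # otherwise binds metadata, loads on the consumer side, and
    # collects finished transfer ids on exit.
    with runner.maybe_get_ec_connector_output(
        scheduler_output,
        encoder_cache=runner.encoder_cache,
    ) as ec_connector_output:
        runner._execute_mm_encoder(scheduler_output)
    return ec_connector_output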


@@ -35,6 +35,7 @@ from vllm.config import (
get_layers_from_vllm_config,
update_config,
)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
@@ -114,12 +115,14 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
ECConnectorOutput,
KVConnectorOutput,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
PoolerOutput,
SamplerOutput,
make_empty_encoder_model_runner_output,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
@@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple):
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
ec_connector_output: ECConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
class GPUModelRunner(
LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
def __init__(
self,
vllm_config: VllmConfig,
@@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
output,
is_embed=pos_info.is_embed,
)
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
def _gather_mm_embeddings(
self,
@@ -2191,20 +2200,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
torch.Tensor,
IntermediateTensors | None,
dict[str, Any],
ECConnectorOutput | None,
]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_first_rank = get_pp_group().is_first_rank
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
ec_connector_output = None
if (
self.supports_mm_inputs
and is_first_rank
and not self.model_config.is_encoder_decoder
):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
@@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
)
def _sample(
@@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update persistent batch states.
self._update_states(scheduler_output)
if has_ec_transfer() and get_ec_transfer().is_producer:
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
return make_empty_encoder_model_runner_output(scheduler_output)
if not num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
@@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
) = self._preprocess(
scheduler_output, num_input_tokens, intermediate_tensors
)
@@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
ec_connector_output,
)
return None
@@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
ec_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
@@ -2811,6 +2839,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
ec_connector_output=ec_connector_output
if self.supports_mm_inputs
else None,
num_nans_in_logits=num_nans_in_logits,
)
@@ -4797,7 +4828,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():


@@ -20,6 +20,7 @@ from vllm.distributed import (
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
@@ -887,3 +888,7 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size,
)
# Init EC connector here before KV cache init
# NOTE: We do not init KV caches for the encoder-only instance in EPD disagg mode
ensure_ec_transfer_initialized(vllm_config)
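
For context, a hedged end-to-end sketch of the 1E + 1PD deployment this patch targets, expressed as the JSON payloads one might pass to --ec-transfer-config on each instance (field names from ECTransferConfig above; the storage path is an example):

import json

# Encoder (producer) instance: runs the mm encoder and saves EC caches.
encoder_cfg = {
    "ec_connector": "ECSharedStorageConnector",
    "ec_role": "ec_producer",
    "ec_connector_extra_config": {"shared_storage_path": "/tmp/ec_cache"},
}
# Prefill/decode (consumer) instance: loads EC caches instead of encoding.
pd_cfg = {**encoder_cfg, "ec_role": "ec_consumer"}
print(json.dumps(encoder_cfg))
print(json.dumps(pd_cfg))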