[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)
Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
@@ -9,6 +9,7 @@ from vllm.config.compilation import (
    PassConfig,
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
@@ -54,6 +55,8 @@ __all__ = [
    "PassConfig",
    # From vllm.config.device
    "DeviceConfig",
    # From vllm.config.ec_transfer
    "ECTransferConfig",
    # From vllm.config.kv_events
    "KVEventsConfig",
    # From vllm.config.kv_transfer
vllm/config/ec_transfer.py (new file, 110 lines)

@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args

from pydantic.dataclasses import dataclass

from vllm.config.utils import config

ECProducer = Literal["ec_producer"]
ECConsumer = Literal["ec_consumer"]
ECRole = Literal[ECProducer, ECConsumer]


@config
@dataclass
class ECTransferConfig:
    """Configuration for distributed EC cache transfer."""

    ec_connector: str | None = None
    """The EC connector for vLLM to transmit EC caches between vLLM instances.
    """

    engine_id: str | None = None
    """The engine id for EC transfers."""

    ec_buffer_device: str | None = "cuda"
    """The device used by ec connector to buffer the EC cache.
    Currently only support 'cuda'."""

    ec_buffer_size: float = 1e9
    """The buffer size for TorchDistributedConnector. Measured in number of
    bytes. Recommended value: 1e9 (about 1GB)."""

    ec_role: ECRole | None = None
    """Whether this vLLM instance produces or consumes EC cache. Choices
    are 'ec_producer', 'ec_consumer'."""

    ec_rank: int | None = None
    """The rank of this vLLM instance in the EC cache transfer. Typical value:
    0 for encoder, 1 for pd instance.
    Currently only 1P1D is supported."""

    ec_parallel_size: int = 1
    """The number of parallel instances for EC cache transfer. For
    PyNcclConnector, this should be 2."""

    ec_ip: str = "127.0.0.1"
    """The EC connector ip, used to build distributed connection."""

    ec_port: int = 14579
    """The EC connector port, used to build distributed connection."""

    ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
    """Any extra config that the connector may need."""

    ec_connector_module_path: str | None = None
    """The Python module path to dynamically load the EC connector from.
    Only supported in V1."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self) -> None:
        if self.engine_id is None:
            self.engine_id = str(uuid.uuid4())

        if self.ec_role is not None and self.ec_role not in get_args(ECRole):
            raise ValueError(
                f"Unsupported ec_role: {self.ec_role}. "
                f"Supported roles are {get_args(ECRole)}"
            )

        if self.ec_connector is not None and self.ec_role is None:
            raise ValueError(
                "Please specify ec_role when ec_connector "
                f"is set, supported roles are {get_args(ECRole)}"
            )

    @property
    def is_ec_transfer_instance(self) -> bool:
        return self.ec_connector is not None and self.ec_role in get_args(ECRole)

    @property
    def is_ec_producer(self) -> bool:
        return self.ec_connector is not None and self.ec_role in get_args(ECProducer)

    @property
    def is_ec_consumer(self) -> bool:
        return self.ec_connector is not None and self.ec_role in get_args(ECConsumer)

    def get_from_extra_config(self, key, default) -> Any:
        return self.ec_connector_extra_config.get(key, default)
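For orientation, a minimal sketch of how an encoder-side (producer) instance might construct this config. The connector name matches the shared-storage connector registered later in this commit; the storage path is an illustrative assumption:

from vllm.config.ec_transfer import ECTransferConfig

# Hypothetical producer-side configuration; ECSharedStorageConnector reads
# "shared_storage_path" via get_from_extra_config("shared_storage_path", "/tmp").
ec_config = ECTransferConfig(
    ec_connector="ECSharedStorageConnector",
    ec_role="ec_producer",
    ec_connector_extra_config={"shared_storage_path": "/tmp/ec_cache"},
)

assert ec_config.is_ec_producer and not ec_config.is_ec_consumer
print(ec_config.engine_id)  # auto-generated by __post_init__ when not supplied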
@@ -28,6 +28,7 @@ from vllm.utils import random_uuid
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .ec_transfer import ECTransferConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
@@ -103,6 +104,8 @@ class VllmConfig:
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
@@ -183,6 +186,10 @@ class VllmConfig:
            vllm_factors.append(self.kv_transfer_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.ec_transfer_config:
            vllm_factors.append(self.ec_transfer_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.additional_config:
            if isinstance(additional_config := self.additional_config, dict):
                additional_config_hash = hashlib.md5(
vllm/distributed/ec_transfer/__init__.py (new file, 14 lines)

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.distributed.ec_transfer.ec_transfer_state import (
    ensure_ec_transfer_initialized,
    get_ec_transfer,
    has_ec_transfer,
)

__all__ = [
    "get_ec_transfer",
    "ensure_ec_transfer_initialized",
    "has_ec_transfer",
]
vllm/distributed/ec_transfer/ec_connector/base.py (new file, 247 lines)

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ECConnectorBase Class for Distributed Encoder Cache &
P2P Encoder cache communication in V1

The class provides the following primitives:
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save the encoder cache.
        has_caches() - check whether an encoder cache exists for
        each mm input of a request
        update_state_after_alloc() - update ECConnector state after
        allocation; this decides whether to load the cache or not
        request_finished() - called when a request is finished; frees
        the cache associated with the request

    Worker-side: runs in each worker, loads/saves the encoder cache
    to/from the connector based on the metadata.
        start_load_caches() - starts loading all EC caches (maybe async)
        save_caches() - saves a computed encoder cache to the connector

        get_finished() - called with ids of finished requests, returns
        ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import torch

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ECConnectorOutput

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


class ECConnectorRole(enum.Enum):
    # Connector running in the scheduler process
    SCHEDULER = 0

    # Connector running in the worker process
    WORKER = 1


class ECConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract Metadata used to communicate between the
    Scheduler ECConnector and Worker ECConnector.
    """

    pass


class ECConnectorBase(ABC):
    def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
        self._connector_metadata: ECConnectorMetadata | None = None
        self._vllm_config = vllm_config
        self._role = role
        if vllm_config.ec_transfer_config is not None:
            self._is_producer = vllm_config.ec_transfer_config.is_ec_producer
        else:
            raise ValueError("ec_transfer_config must be set for ECConnectorBase")

    @property
    def role(self) -> ECConnectorRole:
        return self._role

    @property
    def is_producer(self) -> bool:
        return self._is_producer

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        EC cache loading.

        Args:
            connector_metadata (ECConnectorMetadata): the connector metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._connector_metadata = None

    def _get_connector_metadata(self) -> ECConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            ECConnectorMetadata: the connector metadata.
        """

        # Should only be called while set to valid metadata.
        assert self._connector_metadata is not None
        return self._connector_metadata

    def register_caches(
        self,
        ec_caches: dict[str, torch.Tensor],
    ):
        """
        Initialize with the EC caches.
        Args:
            ec_caches: dictionary of encoder caches
        """
        # TODO: Implement this later for P2P feature
        return

    @abstractmethod
    def start_load_caches(
        self, encoder_cache: dict[str, torch.Tensor], **kwargs
    ) -> None:
        """
        Start loading the cache from the connector into vLLM's encoder cache.

        This method loads the encoder cache based on metadata provided by the scheduler.
        It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
        the `encoder_cache` and `mm_hash` are stored in `kwargs`.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
                data hashes (`mm_hash`) to encoder cache tensors.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        pass

    @abstractmethod
    def save_caches(
        self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
    ) -> None:
        """
        Save the encoder cache to the connector.

        This method saves the encoder cache from the worker's local storage
        to shared storage or another external connector.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
                data hashes (`mm_hash`) to encoder cache tensors.
            mm_hash (str): The hash of the multimodal data whose cache is being saved.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
            A tuple of (sending/saving ids, recving/loading ids): the ids of
            requests that have finished asynchronous transfer (requests that
            previously returned True from request_finished()).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        return None, None

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def has_caches(
        self,
        request: "Request",
    ) -> list[bool]:
        """
        Check whether an encoder cache exists for each mm data of the request.

        Args:
            request (Request): the request object.

        Returns:
            A list of bools whose ith value is True if a cache exists for
            the ith mm_data of the request.
        """
        pass

    @abstractmethod
    def update_state_after_alloc(self, request: "Request", index: int):
        """
        Update ECConnector state after cache allocation for a request.

        Args:
            request (Request): the request object.
            index (int): index of the allocated mm data within the request.
        """
        pass

    @abstractmethod
    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> ECConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        pass

    def update_connector_output(self, connector_output: ECConnectorOutput):
        """
        Update ECConnector state from the worker-side connectors' output.

        Args:
            connector_output (ECConnectorOutput): the worker-side
                connectors' output.
        """
        return

    def request_finished(
        self, request: "Request"
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its encoder cache is freed.

        Returns:
            True if the request is being saved/sent asynchronously and the cache
            should not be freed until the request_id is returned from
            get_finished().
        """
        return False, None
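To illustrate the contract, here is a hypothetical in-memory connector, a sketch only: the class, its metadata type, and the module-level dict standing in for a real transport are not part of this commit. It implements the five abstract methods above:

import torch

from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorMetadata,
    ECConnectorRole,
)

_FAKE_REMOTE_STORE: dict[str, torch.Tensor] = {}  # stands in for a real transport


class InMemoryECMetadata(ECConnectorMetadata):
    def __init__(self, mm_hashes: list[str]):
        self.mm_hashes = mm_hashes


class InMemoryECConnector(ECConnectorBase):
    def __init__(self, vllm_config, role: ECConnectorRole):
        super().__init__(vllm_config, role)
        self._pending: dict[str, int] = {}  # mm_hash -> num encoder tokens

    # ---- Scheduler side ----
    def has_caches(self, request) -> list[bool]:
        return [f.identifier in _FAKE_REMOTE_STORE for f in request.mm_features]

    def update_state_after_alloc(self, request, index: int):
        mm_hash = request.mm_features[index].identifier
        self._pending[mm_hash] = request.get_num_encoder_tokens(index)

    def build_connector_meta(self, scheduler_output) -> ECConnectorMetadata:
        meta = InMemoryECMetadata(list(self._pending))
        self._pending.clear()  # building metadata resets connector state
        return meta

    # ---- Worker side ----
    def start_load_caches(self, encoder_cache, **kwargs) -> None:
        for mm_hash in self._get_connector_metadata().mm_hashes:
            encoder_cache.setdefault(mm_hash, _FAKE_REMOTE_STORE[mm_hash])

    def save_caches(self, encoder_cache, mm_hash: str, **kwargs) -> None:
        if self.is_producer:
            _FAKE_REMOTE_STORE[mm_hash] = encoder_cache[mm_hash].detach().cpu()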
vllm/distributed/ec_transfer/ec_connector/factory.py (new file, 88 lines)

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING

# yapf: disable
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorRole,
)
from vllm.logger import init_logger

# yapf: enable

if TYPE_CHECKING:
    from vllm.config import ECTransferConfig, VllmConfig

logger = init_logger(__name__)


class ECConnectorFactory:
    _registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[ECConnectorBase]:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector(
        cls,
        config: "VllmConfig",
        role: ECConnectorRole,
    ) -> ECConnectorBase:
        ec_transfer_config = config.ec_transfer_config
        if ec_transfer_config is None:
            raise ValueError("ec_transfer_config must be set to create a connector")
        connector_cls = cls.get_connector_class(ec_transfer_config)
        logger.info(
            "Creating connector with name: %s and engine_id: %s",
            connector_cls.__name__,
            ec_transfer_config.engine_id,
        )
        # Connector is explicitly separated into two roles.
        # Scheduler connector:
        #   - Co-located with the scheduler process
        #   - Should only be used inside the Scheduler class
        # Worker connector:
        #   - Co-located with the worker process
        return connector_cls(config, role)

    @classmethod
    def get_connector_class(
        cls, ec_transfer_config: "ECTransferConfig"
    ) -> type[ECConnectorBase]:
        """Get the connector class by name."""
        connector_name = ec_transfer_config.ec_connector
        if connector_name is None:
            raise ValueError("ec_connector must not be None")
        elif connector_name in cls._registry:
            connector_cls = cls._registry[connector_name]()
        else:
            connector_module_path = ec_transfer_config.ec_connector_module_path
            if connector_module_path is None:
                raise ValueError(f"Unsupported connector type: {connector_name}")
            connector_module = importlib.import_module(connector_module_path)
            connector_cls = getattr(connector_module, connector_name)
        return connector_cls


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.

ECConnectorFactory.register_connector(
    "ECSharedStorageConnector",
    "vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
    "ECSharedStorageConnector",
)
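Out-of-tree connectors can then be plugged in two ways; a sketch, with "my_pkg.ec_connectors" standing in for wherever a custom class such as the InMemoryECConnector above might live:

from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory

# (a) Register up front; ec_connector="InMemoryECConnector" then resolves
# lazily through the registry, importing the module only on first creation.
ECConnectorFactory.register_connector(
    "InMemoryECConnector",   # name used in ECTransferConfig.ec_connector
    "my_pkg.ec_connectors",  # hypothetical module path
    "InMemoryECConnector",   # class name inside that module
)

# (b) Or skip registration entirely: get_connector_class() falls back to
# importing ECTransferConfig.ec_connector_module_path and fetching the
# attribute named by ec_connector.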
vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py (new file, 201 lines)

@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

import safetensors

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorMetadata,
    ECConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class MMMeta:
    mm_hash: str
    num_token: int

    @staticmethod
    def make_meta(mm_hash, num_token) -> "MMMeta":
        return MMMeta(mm_hash=mm_hash, num_token=num_token)


@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
    mm_datas: list[MMMeta]

    def __init__(self):
        self.mm_datas = []

    def add_mm_data(self, mm_data: MMMeta):
        self.mm_datas.append(mm_data)


class ECSharedStorageConnector(ECConnectorBase):
    # NOTE: This is a simple debug implementation of the EC connector.
    # It saves / loads the EC cache to / from disk.

    def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        # mm_hash -> number of encoder tokens
        self._mm_datas_need_loads: dict[str, int] = {}
        transfer_config = vllm_config.ec_transfer_config
        if transfer_config is not None:
            self._storage_path = transfer_config.get_from_extra_config(
                "shared_storage_path", "/tmp"
            )
            logger.debug(transfer_config)
            logger.debug("Shared storage path is %s", self._storage_path)
        else:
            raise ValueError("ec_transfer_config must be set for ECConnectorBase")

    def start_load_caches(self, encoder_cache, **kwargs) -> None:
        """
        Start loading the cache from the connector into vLLM's encoder cache.

        This method loads the encoder cache based on metadata provided by the scheduler.
        It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
        the `encoder_cache` and `mm_hash` are stored in `kwargs`.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
                data hashes (`mm_hash`) to encoder cache tensors.
            kwargs (dict): Additional keyword arguments for the connector.
        """

        # Get the metadata
        metadata: ECConnectorMetadata = self._get_connector_metadata()
        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
        assert encoder_cache is not None
        if metadata is None:
            logger.warning(
                "In connector.start_load_caches, "
                "but the connector metadata is None"
            )
            return
        # Load the EC cache for each mm data
        for mm_data in metadata.mm_datas:
            if mm_data.mm_hash in encoder_cache:
                continue
            filename = self._generate_filename_debug(mm_data.mm_hash)
            ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
            encoder_cache[mm_data.mm_hash] = ec_cache
            logger.debug("Successfully loaded encoder cache for hash %s", mm_data.mm_hash)

    def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
        """
        Save the encoder cache to the connector.

        This method saves the encoder cache from the worker's local storage
        to shared storage or another external connector.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
                data hashes (`mm_hash`) to encoder cache tensors.
            mm_hash (str): The hash of the multimodal data whose cache is being saved.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        # Return if it is a PD instance
        if not self.is_producer:
            return
        filename = self._generate_filename_debug(mm_hash)
        ec_cache = encoder_cache[mm_hash]
        tensors = {"ec_cache": ec_cache.detach().cpu()}
        safetensors.torch.save_file(tensors, filename)
        logger.debug("Save cache successful for mm_hash %s", mm_hash)

    def has_caches(
        self,
        request: "Request",
    ) -> list[bool]:
        """
        Check whether a cache exists externally for each mm_data of the request.

        Args:
            request (Request): the request object.

        Returns:
            A list of bools indicating whether the ith mm_data exists in the cache.
        """
        result = []
        for feature in request.mm_features:
            result.append(self._found_match_for_mm_data(feature.identifier))
        return result

    def update_state_after_alloc(
        self,
        request: "Request",
        index: int,
    ) -> None:
        """
        Update ECConnector state after encoder cache allocation.
        """
        mm_hash = request.mm_features[index].identifier
        num_encoder_token = request.get_num_encoder_tokens(index)
        # Insert mm_hash only if this block has not been recorded yet.
        self._mm_datas_need_loads[mm_hash] = num_encoder_token

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> ECConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.
        This only builds metadata for loading mm_data.
        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = ECSharedStorageConnectorMetadata()
        for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
            meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
        self._mm_datas_need_loads.clear()
        return meta

    # ==============================
    # Helper functions
    # ==============================

    def _found_match_for_mm_data(self, mm_hash) -> bool:
        """Check if the cache is hit for the request."""
        filename = self._generate_filename_debug(mm_hash)
        return os.path.exists(filename)

    def _generate_foldername_debug(
        self,
        mm_hash: str,
        create_folder: bool = True,
    ) -> str:
        """
        Return the folder in which the cache for this mm_hash lives.
        If `create_folder` is True (default) the directory is created
        recursively the first time it is needed.
        """
        foldername = os.path.join(self._storage_path, mm_hash)
        if create_folder:
            os.makedirs(foldername, exist_ok=True)
        return foldername

    def _generate_filename_debug(self, mm_hash: str) -> str:
        """
        Return the full path of the safetensors file for this mm_hash.
        Ensures the parent directory exists because
        `_generate_foldername_debug` is called with its default
        (`create_folder=True`).
        """
        foldername = self._generate_foldername_debug(mm_hash)  # folder auto-created
        return os.path.join(foldername, "encoder_cache.safetensors")
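Given the helpers above, the on-disk layout is <shared_storage_path>/<mm_hash>/encoder_cache.safetensors. A sketch of inspecting a saved cache out of band; the hash and path are made-up examples:

import safetensors.torch

mm_hash = "0123abcd"  # hypothetical multimodal data hash
path = f"/tmp/ec_cache/{mm_hash}/encoder_cache.safetensors"

# The producer stores the tensor under the "ec_cache" key (see save_caches).
ec_cache = safetensors.torch.load_file(path)["ec_cache"]
print(ec_cache.shape)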
vllm/distributed/ec_transfer/ec_transfer_state.py (new file, 46 lines)

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING

from vllm import envs
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory

if TYPE_CHECKING:
    from vllm.config import VllmConfig

_EC_CONNECTOR_AGENT: ECConnectorBase | None = None


def get_ec_transfer() -> ECConnectorBase:
    assert _EC_CONNECTOR_AGENT is not None, "disaggregated EC cache is not initialized"
    return _EC_CONNECTOR_AGENT


def has_ec_transfer() -> bool:
    return _EC_CONNECTOR_AGENT is not None


def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize EC cache connector.
    """

    global _EC_CONNECTOR_AGENT

    if vllm_config.ec_transfer_config is None:
        return

    if (
        vllm_config.ec_transfer_config.is_ec_transfer_instance
        and _EC_CONNECTOR_AGENT is None
    ):
        if envs.VLLM_USE_V1:
            _EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector(
                config=vllm_config, role=ECConnectorRole.WORKER
            )
        else:
            raise ValueError("V0 is no longer supported")
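A sketch of how worker code might consult the resulting process-global handle; the helper function below is hypothetical, and it assumes a VllmConfig carrying a populated ec_transfer_config:

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer import (
    ensure_ec_transfer_initialized,
    get_ec_transfer,
    has_ec_transfer,
)


def maybe_report_ec_role(vllm_config: VllmConfig) -> None:
    """Sketch: initialize the process-global connector and inspect it."""
    ensure_ec_transfer_initialized(vllm_config)  # no-op if EC transfer is unset
    if has_ec_transfer():
        connector = get_ec_transfer()  # the WORKER-role connector for this process
        print("is_producer:", connector.is_producer)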
@@ -38,6 +38,7 @@ from vllm.config import (
    CompilationConfig,
    ConfigType,
    DeviceConfig,
    ECTransferConfig,
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
@@ -527,6 +528,8 @@ class EngineArgs:
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None

    ec_transfer_config: ECTransferConfig | None = None

    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = get_field(
@@ -1105,6 +1108,9 @@ class EngineArgs:
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
@@ -1676,6 +1682,7 @@ class EngineArgs:
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
            ec_transfer_config=self.ec_transfer_config,
            additional_config=self.additional_config,
        )
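A sketch of wiring a pair of instances through these args; the model name and storage path are illustrative assumptions, and both sides point at the same shared path:

from vllm.config import ECTransferConfig
from vllm.engine.arg_utils import EngineArgs

shared = {"shared_storage_path": "/tmp/ec_cache"}

# Encoder (E) instance: computes encoder caches and saves them.
producer_args = EngineArgs(
    model="Qwen/Qwen2-VL-7B-Instruct",  # illustrative multimodal model
    ec_transfer_config=ECTransferConfig(
        ec_connector="ECSharedStorageConnector",
        ec_role="ec_producer",
        ec_connector_extra_config=shared,
    ),
)

# Prefill/decode (PD) instance: loads caches instead of running the encoder.
consumer_args = EngineArgs(
    model="Qwen/Qwen2-VL-7B-Instruct",
    ec_transfer_config=ECTransferConfig(
        ec_connector="ECSharedStorageConnector",
        ec_role="ec_consumer",
        ec_connector_extra_config=shared,
    ),
)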
@@ -49,10 +49,18 @@ def kernel_warmup(worker: "Worker"):
     except NotImplementedError:
         return False

-    if not worker.model_runner.is_pooling_model and all(
-        _is_flashinfer_backend(group.backend)
-        for groups in worker.model_runner.attn_groups
-        for group in groups
-    ):
+    # NOTE: we add a check for empty attn_groups to avoid errors when
+    # deploying models such as E instances and encoder-only models.
+    # For those models, worker.model_runner.attn_groups is empty.
+    # This change is made during EPD feature development.
+    if (
+        not worker.model_runner.is_pooling_model
+        and worker.model_runner.attn_groups
+        and all(
+            _is_flashinfer_backend(group.backend)
+            for groups in worker.model_runner.attn_groups
+            for group in groups
+        )
+    ):
         logger.info("Warming up FlashInfer attention.")
         # Warmup with mixed batch containing both prefill and decode tokens
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
    import numpy.typing as npt
    import torch

    from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
    from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
    from vllm.lora.request import LoRARequest
    from vllm.multimodal.inputs import MultiModalFeatureSpec
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request
else:
    ECConnectorMetadata = object
    KVConnectorMetadata = object
    LoRARequest = object
    MultiModalFeatureSpec = object
@@ -188,6 +190,9 @@ class SchedulerOutput:
    # KV Cache Connector metadata.
    kv_connector_metadata: KVConnectorMetadata | None = None

    # EC Cache Connector metadata
    ec_connector_metadata: ECConnectorMetadata | None = None


@dataclass
class GrammarOutput:
@@ -7,6 +7,11 @@ from collections.abc import Iterable
from typing import Any

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorMetadata,
    ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorRole,
    SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
                self.kv_events_config,
                self.parallel_config.data_parallel_rank,
            )
        self.ec_connector = None
        if self.vllm_config.ec_transfer_config is not None:
            self.ec_connector = ECConnectorFactory.create_connector(
                config=self.vllm_config, role=ECConnectorRole.SCHEDULER
            )

        num_gpu_blocks = self.cache_config.num_gpu_blocks
        assert num_gpu_blocks is not None and num_gpu_blocks > 0
@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):

            # Schedule encoder inputs.
            encoder_inputs_to_schedule = None
            external_load_encoder_input: list[int] = []
            new_encoder_compute_budget = encoder_compute_budget
            if request.has_encoder_inputs:
                (
                    encoder_inputs_to_schedule,
                    num_new_tokens,
                    new_encoder_compute_budget,
                    external_load_encoder_input,
                ) = self._try_schedule_encoder_inputs(
                    request,
                    request.num_computed_tokens,
@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface):
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_compute_budget = new_encoder_compute_budget
            if external_load_encoder_input:
                for i in external_load_encoder_input:
                    self.encoder_cache_manager.allocate(request, i)
                    if self.ec_connector is not None:
                        self.ec_connector.update_state_after_alloc(request, i)

        # Record the LoRAs in scheduled_running_reqs
        scheduled_loras: set[int] = set()
@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface):
                num_computed_tokens = request.num_computed_tokens

                encoder_inputs_to_schedule = None
                external_load_encoder_input = []
                new_encoder_compute_budget = encoder_compute_budget

                # KVTransfer: loading remote KV, do not allocate for new work.
@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface):
                        encoder_inputs_to_schedule,
                        num_new_tokens,
                        new_encoder_compute_budget,
                        external_load_encoder_input,
                    ) = self._try_schedule_encoder_inputs(
                        request,
                        num_computed_tokens,
@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface):
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_compute_budget = new_encoder_compute_budget

                # Allocate for external load encoder cache
                if external_load_encoder_input:
                    for i in external_load_encoder_input:
                        self.encoder_cache_manager.allocate(request, i)
                        if self.ec_connector is not None:
                            self.ec_connector.update_state_after_alloc(request, i)
        # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
            self.waiting.prepend_requests(skipped_waiting_requests)
@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface):
        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens

        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface):
        # 2. Wrap up all the KV cache load / save ops into an opaque object
        # 3. Clear the internal states of the connector
        if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
+            meta: KVConnectorMetadata = self.connector.build_connector_meta(
+                scheduler_output
+            )
            scheduler_output.kv_connector_metadata = meta

+        # Build the connector meta for ECConnector
+        if self.ec_connector is not None:
+            ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
+                scheduler_output
+            )
+            scheduler_output.ec_connector_metadata = ec_meta
+
        with record_function_or_nullcontext("schedule: update_after_schedule"):
            self._update_after_schedule(scheduler_output)
        return scheduler_output
@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface):
        num_computed_tokens: int,
        num_new_tokens: int,
        encoder_compute_budget: int,
-    ) -> tuple[list[int], int, int]:
+    ) -> tuple[list[int], int, int, list[int]]:
        """
        Determine which encoder inputs need to be scheduled in the current step,
        and update `num_new_tokens` and the encoder token budget accordingly.
@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface):
          in this step, i.e.,
          [num_computed_tokens, num_computed_tokens + num_new_tokens).
        - It is not already computed and stored in the encoder cache.
        - It does not exist in the remote encoder cache (via ECConnector).
        - There is sufficient encoder token budget to process it.
        - The encoder cache has space to store it.

@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface):
          blocks and externally cached blocks (via KVConnector).
        """
        if num_new_tokens == 0 or not request.has_encoder_inputs:
-            return [], num_new_tokens, encoder_compute_budget
+            return [], num_new_tokens, encoder_compute_budget, []
        encoder_inputs_to_schedule: list[int] = []
        mm_features = request.mm_features
        assert mm_features is not None
        assert len(mm_features) > 0
        external_load_encoder_input = []

        # Check remote cache first
        if self.ec_connector is not None:
            remote_cache_has_item = self.ec_connector.has_caches(request)
        # NOTE: since scheduler operates on the request level (possibly with
        # multiple encoder inputs per request), we need to create temporary
        # trackers for accounting at the encoder input level.
@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface):
                num_new_tokens = 0
                break

            if self.ec_connector is not None and remote_cache_has_item[i]:
                mm_hashes_to_schedule.add(request.mm_features[i].identifier)
                external_load_encoder_input.append(i)
                num_tokens_to_schedule += num_encoder_tokens
                continue

            num_tokens_to_schedule += num_encoder_tokens
            encoder_compute_budget -= num_encoder_tokens
            mm_hashes_to_schedule.add(request.mm_features[i].identifier)
@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface):
            encoder_inputs_to_schedule,
            num_new_tokens,
            encoder_compute_budget,
            external_load_encoder_input,
        )

    def get_grammar_bitmask(
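To make the new control flow concrete, here is a standalone sketch (plain Python, not vLLM code; all names are illustrative) of the per-input decision _try_schedule_encoder_inputs now makes: local-cache hits are skipped, remote hits are routed to external loading without consuming encoder compute budget, and everything else is scheduled for local encoding.

def classify_encoder_inputs(
    in_local_cache: list[bool],
    in_remote_cache: list[bool],
) -> tuple[list[int], list[int]]:
    to_compute: list[int] = []        # run the encoder locally
    to_load_external: list[int] = []  # allocate cache space, load via ECConnector
    for i, (local, remote) in enumerate(zip(in_local_cache, in_remote_cache)):
        if local:
            continue  # already in the encoder cache, nothing to do
        if remote:
            to_load_external.append(i)  # costs cache space but no compute budget
        else:
            to_compute.append(i)  # costs both cache space and compute budget
    return to_compute, to_load_external


# Example: input 0 cached locally, input 1 available remotely, input 2 fresh.
print(classify_encoder_inputs([True, False, False], [False, True, False]))
# -> ([2], [1])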
@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch

from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
@@ -136,6 +138,13 @@ class KVConnectorOutput:
    )


@dataclass
class ECConnectorOutput:
    # [mm_hash]
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None


# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
@@ -167,6 +176,8 @@ class ModelRunnerOutput:

    kv_connector_output: KVConnectorOutput | None = None

    ec_connector_output: ECConnectorOutput | None = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: dict[str, int] | None = None

@@ -192,6 +203,41 @@ class DraftTokenIds:
    draft_token_ids: list[list[int]]


def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """
    Create a ModelRunnerOutput stub that contains the correct
    per-request bookkeeping but no generated data yet.
    """
    if not scheduler_output.num_scheduled_tokens:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Convert to a list so we get a deterministic, indexable sequence
    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())

    # Give every request its own contiguous index
    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}

    # No real tokens generated yet ⇒ a single placeholder token per request
    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]

    # Pooler outputs are not available yet ⇒ use None placeholders
    pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=pooler_output,
        kv_connector_output=None,
        ec_connector_output=None,
        num_nans_in_logits=None,
    )


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
    req_ids=[],
    req_id_to_index={},
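A quick check of the stub's bookkeeping; a sketch using SimpleNamespace in place of a real SchedulerOutput, which works here only because the function reads nothing but num_scheduled_tokens:

from types import SimpleNamespace

from vllm.v1.outputs import make_empty_encoder_model_runner_output

fake_sched_out = SimpleNamespace(num_scheduled_tokens={"req-0": 16, "req-1": 8})
out = make_empty_encoder_model_runner_output(fake_sched_out)
print(out.req_ids)          # ['req-0', 'req-1']
print(out.req_id_to_index)  # {'req-0': 0, 'req-1': 1}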
vllm/v1/worker/ec_connector_model_runner_mixin.py (new file, 87 lines)

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define EC connector functionality mixin for model runners.
"""

from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
    TYPE_CHECKING,  # noqa: UP035
)

import torch

from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)


# Defined as an EC connector functionality mixin for ModelRunner (GPU, TPU)
class ECConnectorModelRunnerMixin:
    @staticmethod
    def maybe_save_ec_to_connector(
        encoder_cache: dict[str, torch.Tensor],
        mm_hash: str,
    ):
        if not has_ec_transfer():
            logger.debug("EC transfer is not initialized; skipping save")
            return
        connector = get_ec_transfer()
        connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)

    @staticmethod
    def get_finished_ec_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[set[str] | None, set[str] | None]:
        if has_ec_transfer():
            return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
        return None, None

    @staticmethod
    def maybe_get_ec_connector_output(
        scheduler_output: "SchedulerOutput",
        encoder_cache: dict[str, torch.Tensor],
        **kwargs,
    ) -> AbstractContextManager[ECConnectorOutput | None]:
        return (
            ECConnectorModelRunnerMixin._get_ec_connector_output(
                scheduler_output, encoder_cache, **kwargs
            )
            if has_ec_transfer()
            else nullcontext()
        )

    # This context manager must be used within an active forward context.
    # It encapsulates the entire EC connector lifecycle within execute_model
    @staticmethod
    @contextmanager
    def _get_ec_connector_output(
        scheduler_output: "SchedulerOutput",
        encoder_cache: dict[str, torch.Tensor],
        **kwargs,
    ) -> Generator[ECConnectorOutput, None, None]:
        output = ECConnectorOutput()

        ec_connector = get_ec_transfer()
        assert isinstance(ec_connector, ECConnectorBase)
        assert scheduler_output.ec_connector_metadata is not None
        ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)

        if not ec_connector.is_producer:
            ec_connector.start_load_caches(encoder_cache, **kwargs)

        try:
            yield output
        finally:
            output.finished_sending, output.finished_recving = (
                ec_connector.get_finished(scheduler_output.finished_req_ids)
            )

            ec_connector.clear_connector_metadata()
@@ -35,6 +35,7 @@ from vllm.config import (
    get_layers_from_vllm_config,
    update_config,
)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
@@ -114,12 +115,14 @@ from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    ECConnectorOutput,
    KVConnectorOutput,
    LogprobsLists,
    LogprobsTensors,
    ModelRunnerOutput,
    PoolerOutput,
    SamplerOutput,
    make_empty_encoder_model_runner_output,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
@@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple):
    sample_hidden_states: torch.Tensor
    aux_hidden_states: list[torch.Tensor] | None
    kv_connector_output: KVConnectorOutput | None
    ec_connector_output: ECConnectorOutput | None


-class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
+class GPUModelRunner(
+    LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
+):
    def __init__(
        self,
        vllm_config: VllmConfig,
@@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                output,
                is_embed=pos_info.is_embed,
            )
            logger.debug("Finish execute for mm hash %s", mm_hash)
            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)

    def _gather_mm_embeddings(
        self,
@@ -2191,20 +2200,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        torch.Tensor,
        IntermediateTensors | None,
        dict[str, Any],
        ECConnectorOutput | None,
    ]:
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        is_first_rank = get_pp_group().is_first_rank

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        ec_connector_output = None

        if (
            self.supports_mm_inputs
            and is_first_rank
            and not self.model_config.is_encoder_decoder
        ):
            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                scheduler_output,
+                encoder_cache=self.encoder_cache,
+            ) as ec_connector_output:
+                self._execute_mm_encoder(scheduler_output)
+                mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
@@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            positions,
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
        )

    def _sample(
@@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Update persistent batch states.
        self._update_states(scheduler_output)

        if has_ec_transfer() and get_ec_transfer().is_producer:
            with self.maybe_get_ec_connector_output(
                scheduler_output,
                encoder_cache=self.encoder_cache,
            ) as ec_connector_output:
                self._execute_mm_encoder(scheduler_output)
            return make_empty_encoder_model_runner_output(scheduler_output)

        if not num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if no work to do.
@@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            positions,
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
        ) = self._preprocess(
            scheduler_output, num_input_tokens, intermediate_tensors
        )
@@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            sample_hidden_states,
            aux_hidden_states,
            kv_connector_output,
            ec_connector_output,
        )
        return None

@@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            sample_hidden_states,
            aux_hidden_states,
            kv_connector_output,
            ec_connector_output,
        ) = self.execute_model_state
        # Clear ephemeral state.
        self.execute_model_state = None
@@ -2811,6 +2839,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            kv_connector_output=kv_connector_output,
            ec_connector_output=ec_connector_output
            if self.supports_mm_inputs
            else None,
            num_nans_in_logits=num_nans_in_logits,
        )

@@ -4797,7 +4828,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """

        if has_ec_transfer() and get_ec_transfer().is_producer:
            return {}
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
        for layer_name, attn_module in attn_layers.items():

@@ -20,6 +20,7 @@ from vllm.distributed import (
    init_distributed_environment,
    set_custom_all_reduce,
)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
@@ -887,3 +888,7 @@ def init_worker_distributed_environment(
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size,
    )

    # Init the EC connector here, before KV cache init
    # NOTE: We do not init KV caches for the encoder-only instance in EPD disagg mode
    ensure_ec_transfer_initialized(vllm_config)