[P/D] NIXL Integration (#17751)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Brent Salisbury <bsalisbu@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Brent Salisbury <bsalisbu@redhat.com>
This commit is contained in:
Robert Shaw
2025-05-12 12:46:16 -04:00
committed by GitHub
parent 05a4324f8e
commit d19110204c
34 changed files with 2723 additions and 108 deletions

View File

@@ -8,6 +8,7 @@ import inspect
import json
import re
import textwrap
import uuid
import warnings
from collections import Counter
from contextlib import contextmanager
@@ -3438,6 +3439,9 @@ class KVTransferConfig:
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: str = str(uuid.uuid4())
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
@@ -3448,7 +3452,7 @@ class KVTransferConfig:
kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'both'."""
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:

View File

@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1")
KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"NixlConnector")

View File

@@ -1,8 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole)
KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
__all__ = [
"KVConnectorRole",
"KVConnectorBase_V1",
]
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]

View File

@@ -23,7 +23,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -34,6 +34,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVTransferParams:
    """
    Abstract base for KV transfer parameters exchanged between vLLM
    instances.

    Concrete connectors subclass this and override :meth:`from_raw_dict`
    to define how raw request dicts (sent over the HTTP protocol) are
    validated and deserialized.
    """

    @staticmethod
    def from_raw_dict(
        raw_dict: Optional[dict[str, Any]]
    ) -> Optional["KVTransferParams"]:
        # Base implementation: no connector-specific params are recognized.
        return None
@dataclass
class KVConnectorMetadata:
    """
    Abstract metadata container passed from the scheduler-side
    KVConnector to the worker-side KVConnector each step.
    """
class KVConnectorBase_V1(ABC):
_KVTransferParams = KVTransferParams
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
@@ -66,6 +89,10 @@ class KVConnectorBase_V1(ABC):
def role(self) -> KVConnectorRole:
return self._role
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
@@ -97,9 +124,15 @@ class KVConnectorBase_V1(ABC):
"""
return self._connector_metadata
# ==============================
# Worker-side methods
# ==============================
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches:
dictionary of layer names, kv cache
"""
return
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
@@ -162,15 +195,37 @@ class KVConnectorBase_V1(ABC):
"""
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous (recving, sending).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
# ==============================
# Scheduler-side methods
# ==============================
def set_kv_transfer_params(self, request: "Request"):
"""Parse raw KV Transfer params."""
assert request.kv_transfer_params is None
kv_transfer_params = self._KVTransferParams.from_raw_dict(
request.raw_kv_transfer_params)
request.kv_transfer_params = kv_transfer_params
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
@@ -181,13 +236,16 @@ class KVConnectorBase_V1(ABC):
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
@@ -207,3 +265,20 @@ class KVConnectorBase_V1(ABC):
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return False, None

View File

@@ -13,6 +13,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -92,7 +93,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> int:
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
@@ -107,9 +108,10 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens)
request, num_computed_tokens), False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.

View File

@@ -0,0 +1,805 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import math
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator
import msgspec
import torch
import zmq
from typing_extensions import Optional
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.utils import round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
@dataclass
class NixlKVTransferParams(KVTransferParams):
    """KV transfer parameters for the NIXL connector.

    NOTE: the original version paired ``@dataclass`` with a handwritten
    ``__init__``, so no fields were registered and no dataclass machinery
    (``__repr__``/``__eq__``) was generated. Declaring the fields keeps the
    constructor signature identical while making the decorator meaningful.
    ``request_finished`` relies on instances exposing ``__dict__`` — plain
    (non-slots) dataclasses still do.
    """

    # Whether this request should pull prompt KV from a remote prefiller.
    do_remote_prefill: bool
    # Whether a remote decoder will pull KV from this instance.
    do_remote_decode: bool
    # Block ids on the remote instance holding the computed KV.
    remote_block_ids: Optional[list[int]] = None
    # Side-channel host/port of the remote instance.
    remote_host: Optional[str] = None
    remote_port: Optional[int] = None
    # Engine id of the remote instance.
    remote_engine_id: Optional[str] = None

    # Keys that must all be present in a raw request dict (class attribute,
    # not a dataclass field: it has no annotation).
    _REQUIRED_KEYS = (
        "do_remote_prefill",
        "do_remote_decode",
        "remote_block_ids",
        "remote_host",
        "remote_port",
        "remote_engine_id",
    )

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str,
                                    Any]]) -> Optional["NixlKVTransferParams"]:
        """Validate and deserialize raw request params.

        Returns None (and logs a warning) when the dict is absent or
        missing any required key, in which case the request proceeds
        without KV transfer.
        """
        # If no raw transfer params passed, return None.
        if raw_dict is None:
            return None

        # Validate the request is formatted properly.
        if any(key not in raw_dict
               for key in NixlKVTransferParams._REQUIRED_KEYS):
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer", raw_dict)
            return None

        # Only the known keys are forwarded; extras are ignored.
        return NixlKVTransferParams(
            **{k: raw_dict[k]
               for k in NixlKVTransferParams._REQUIRED_KEYS})
class NixlAgentMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):
    """Handshake payload served over the ZMQ side channel so a remote
    worker can register this instance as a NIXL agent."""
    # Engine id of the instance that produced this metadata.
    engine_id: str
    # Opaque NIXL agent blob from nixl_wrapper.get_agent_metadata().
    agent_metadata: bytes
    # Base addresses (data_ptr) of each registered KV cache region.
    kv_caches_base_addr: list[int]
    # Number of KV blocks per region on the remote instance.
    num_blocks: int
@dataclass
class ReqMeta:
    """Per-request transfer descriptor sent from scheduler to worker."""
    # Block ids allocated locally that the remote KV will be written into.
    local_block_ids: list[int]
    # Block ids on the remote instance to read from.
    remote_block_ids: list[int]
    # Side-channel address of the remote instance.
    remote_host: str
    remote_port: int
    # Engine id of the remote instance (keys the cached NIXL handles).
    remote_engine_id: str
class NixlConnectorMetadata(KVConnectorMetadata):
    """Scheduler -> worker metadata: requests whose remote KV blocks
    must start being pulled this step."""

    def __init__(self):
        # request_id -> transfer descriptor for new remote-prefill reqs.
        self.requests: dict[str, ReqMeta] = {}

    def add_new_req(
        self,
        request_id: str,
        local_block_ids: list[int],
        kv_transfer_params: NixlKVTransferParams,
    ):
        """Record a request whose remote blocks should be read."""
        params = kv_transfer_params
        assert request_id not in self.requests
        # All remote-side fields must have been validated upstream.
        for attr in ("remote_block_ids", "remote_engine_id", "remote_host",
                     "remote_port"):
            assert getattr(params, attr) is not None
        self.requests[request_id] = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=params.remote_block_ids,
            remote_engine_id=params.remote_engine_id,
            remote_host=params.remote_host,
            remote_port=params.remote_port,
        )
class NixlConnector(KVConnectorBase_V1):
    """Facade that dispatches KVConnector calls to either a scheduler-side
    or a worker-side implementation, depending on this process's role.
    Exactly one of connector_scheduler / connector_worker is non-None."""

    # Used by the base class to parse raw request params (set_kv_transfer_params).
    _KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id

        # NOTE(review): no else-branch — an unknown role leaves both
        # attributes unset; KVConnectorRole has only these two members.
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler : Optional[NixlConnectorScheduler] = \
                NixlConnectorScheduler(vllm_config, str(self.engine_id))
            self.connector_worker: Optional[NixlConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = NixlConnectorWorker(str(self.engine_id))

    ############################################################
    # Scheduler Side Methods
    ############################################################
    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Delegate to the worker-side implementation."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        # Note: finished_req_ids is unused here; the worker tracks its own
        # in-flight transfers.
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Kick off async reads for the requests in the bound metadata."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """NixlConnector does not do layerwise saving."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """NixlConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """NixlConnector does not save explicitly."""
        pass
class NixlConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        logger.info("Initializing NIXL Scheduler %s", engine_id)

        # Requests that need to start recv.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        # No KVTransfer for this request.
        if request.kv_transfer_params is None:
            return 0, False
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        # Remote prefill: get all prompt blocks from remote.
        if request.kv_transfer_params.do_remote_prefill:
            # The remote only holds fully-computed blocks, so round the
            # prompt length down to a block boundary.
            assert num_computed_tokens % self.block_size == 0
            rounded_num_prompt_tokens = round_down(
                len(request.prompt_token_ids), self.block_size)
            count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
            # Async load iff there is anything to pull.
            return count, count > 0

        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Queue the request for remote-KV recv once blocks are allocated."""
        if request.kv_transfer_params is None:
            return
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        if request.kv_transfer_params.do_remote_prefill:
            # NOTE(rob): if prompt < block_size, no remote blocks
            # since the remote only sends fully computed blocks, so
            # skip recving for this request. num_external_tokens
            # should be 0 if there are no remote blocks.
            if request.kv_transfer_params.remote_block_ids:
                # Get unhashed blocks to pull from remote.
                self._reqs_need_recv[request.request_id] = (
                    request, blocks.get_unhashed_block_ids())
            else:
                assert num_external_tokens == 0
            # Only trigger 1 KV transfer per request.
            request.kv_transfer_params.do_remote_prefill = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Convert pending recv requests into metadata for the workers."""
        meta = NixlConnectorMetadata()

        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.

        Returns (delay_free_blocks, kv_transfer_params-dict-or-None); the
        dict tells a remote decoder where to pull the computed blocks from.
        """
        if request.kv_transfer_params is None:
            return False, None
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        # Only producer-side requests that hit the length cap (remote
        # decode handoff) keep their blocks alive for the remote read.
        if ((not request.kv_transfer_params.do_remote_decode)
                or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
            return False, None

        # Get computed blocks.
        all_full = request.num_computed_tokens % self.block_size == 0
        # A trailing partially-filled block is never sent.
        computed_block_ids = (block_ids if all_full else block_ids[:-1])

        # If prompt < block_size, no xfer so free blocks immediately.
        delay_free_blocks = len(computed_block_ids) > 0

        # __dict__ serializes the params into a plain dict for the
        # request output.
        return delay_free_blocks, NixlKVTransferParams(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=computed_block_ids,
            remote_engine_id=self.engine_id,
            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
        ).__dict__
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
    def __init__(self, engine_id: str):
        """Create worker-side NIXL state for this TP rank.

        Args:
            engine_id: id of this vLLM engine instance; keys all the
                per-remote-engine maps below.

        Raises:
            RuntimeError: if the nixl package could not be imported.
        """
        if NixlWrapper is None:
            logger.error("NIXL is not available")
            raise RuntimeError("NIXL is not available")
        logger.info("Initializing NIXL wrapper")
        logger.info("Initializing NIXL worker %s", engine_id)

        # Agent.
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
        # Map of engine_id -> agent_name.
        self._remote_agents: dict[str, str] = {}

        # Metadata.
        self.engine_id = engine_id
        self.rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()
        self.tp_group = get_tp_group()

        # KV Caches and nixl tracking data.
        self.kv_caches: dict[str, torch.Tensor] = {}

        # Map of engine_id -> kv_caches_base_addr
        self.kv_caches_base_addr: dict[str, list[int]] = {}

        # Number of NIXL regions. Currently one region per cache
        # (so 1 per layer for MLA, otherwise 2 per layer)
        self.num_regions = 0

        # nixl_prepped_dlist_handle (int).
        self.src_xfer_side_handle: int = 0
        # Map of engine_id -> nixl_prepped_dlist_handle (int)].
        self.dst_xfer_side_handles: dict[str, int] = {}

        # Map of engine_id -> num_blocks.
        self.dst_num_blocks: dict[str, int] = {}
        # Keeps registered descriptors alive for the worker's lifetime.
        self._registered_descs: list[Any] = []

        # In progress transfers.
        # [req_id -> list[handle]]
        self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
            list[Any])

        # Complete transfer tracker. Used by the rank 0 to track finished
        # transactions on ranks 1 to N-1.
        # [req_id -> count]
        self._done_recving_count: defaultdict[str,
                                              int] = defaultdict(lambda: 0)
        self._done_sending_count: defaultdict[str,
                                              int] = defaultdict(lambda: 0)

        # Background thread for establishing new connections.
        self._nixl_handshake_listener_t: Optional[threading.Thread] = None
    @staticmethod
    def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                 ready_event: threading.Event, rank: int):
        """Background thread for getting new NIXL handshakes.

        Serves `metadata` (encoded once, up front) to any peer that sends
        GET_META_MSG on the ZMQ ROUTER socket. Runs forever; the thread is
        created as a daemon so it dies with the process.
        """
        # NOTE(rob): this is a simple implementation. We will move
        # to a better approach like an ETCD server in the future.

        # NOTE(rob): to support heterogeneous TP, we will have to
        # move this into the scheduler rather than worker, since
        # each rank needs the metadata of all other ranks (whereas
        # in this setup, each rank only gets one other rank's meta.

        encoder = msgspec.msgpack.Encoder()
        encoded_data = encoder.encode(metadata)
        size_in_bytes = len(encoded_data)
        logger.debug("Size of encoded NixlAgentMetadata: %s bytes",
                     str(size_in_bytes))

        # Listen for new requests for metadata.
        host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        # NOTE(rob): we need each rank to have a unique port. This
        # hack keeps us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.
        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
        path = f"tcp://{host}:{port}"
        logger.debug("Starting listening on path: %s", path)
        with zmq_ctx(zmq.ROUTER, path) as sock:
            # Signal register_kv_caches() that the socket is bound.
            ready_event.set()
            while True:
                identity, _, msg = sock.recv_multipart()
                if msg != GET_META_MSG:
                    logger.warning(
                        "Connection listener got unexpected message %s", msg)
                # NOTE(review): the metadata is sent even after an
                # unexpected message (warn-and-serve) — confirm intended.
                sock.send_multipart((identity, b"", encoded_data))
    def _nixl_handshake(self, host: str, port: int):
        """Do a NIXL handshake with a remote instance.

        Fetches the remote NixlAgentMetadata over the ZMQ side channel
        (blocking) and registers the remote agent locally.
        """
        start_time = time.perf_counter()
        # NOTE(rob): we need each rank to have a unique port. This is
        # a hack to keep us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.
        # Each rank talks to its counterpart rank on the remote.
        path = f"tcp://{host}:{port + self.rank}"
        logger.debug("Querying metadata on path: %s", path)
        with zmq_ctx(zmq.REQ, path) as sock:
            # Send query for the request.
            sock.send(GET_META_MSG)
            metadata_bytes = sock.recv()
            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
            metadata = decoder.decode(metadata_bytes)
            got_metadata_time = time.perf_counter()

            # Register Remote agent.
            self.add_remote_agent(metadata)
            setup_agent_time = time.perf_counter()

            logger.debug("NIXL handshake: get metadata took: %s",
                         got_metadata_time - start_time)
            logger.debug("NIXL handshake: add agent took: %s",
                         setup_agent_time - got_metadata_time)
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl.

        Registers one NIXL region per cache tensor (K and V separately in
        the non-MLA case), then starts the background handshake listener
        so remote instances can fetch our agent metadata.
        """
        _, first_kv_cache = next(iter(kv_caches.items()))
        kv_elem_size = first_kv_cache.element_size()

        # TODO(tms): Find a more robust way to detect and handle MLA
        # NOTE(review): rank-3 == MLA is a heuristic — confirm it holds
        # for every attention backend in use.
        use_mla = len(first_kv_cache.shape) == 3
        if use_mla:
            # MLA case.
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 2  # [block_size, latent_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
        else:
            # [2 (k and v), num_blocks, ...]
            self.num_blocks = first_kv_cache.shape[1]
            block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]

        # TODO(tms): self.block_len needs to be per-layer for sliding window,
        # hybrid attn, etc
        self.block_len = kv_elem_size * math.prod(block_shape)

        logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
                     first_kv_cache.shape)
        logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
                     block_shape)
        logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
        self.dst_num_blocks[self.engine_id] = self.num_blocks
        self.kv_caches = kv_caches
        kv_caches_base_addr = []
        caches_data = []

        # Note(tms): I modified this from the original region setup code.
        # K and V are now in different regions. Advantage is that we can
        # elegantly support MLA and any cases where the K and V tensors
        # are non-contiguous (it's not locally guaranteed that they will be)
        # Disadvantage is that the encoded NixlAgentMetadata is now larger
        # (roughly 8KB vs 5KB).
        for cache_or_caches in kv_caches.values():
            # Normalize to always be a list of caches
            cache_list = [cache_or_caches] if use_mla else cache_or_caches
            for cache in cache_list:
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len
                caches_data.append((base_addr, region_len, self.rank, ""))
                kv_caches_base_addr.append(base_addr)
        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
        self.num_regions = len(caches_data)

        descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs)
        logger.debug("Done registering descs")

        # Keep the descriptors alive while registered.
        self._registered_descs.append(descs)

        # After KV Caches registered, listen for new connections.
        metadata = NixlAgentMetadata(
            engine_id=self.engine_id,
            agent_metadata=self.nixl_wrapper.get_agent_metadata(),
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
            num_blocks=self.num_blocks,
        )
        ready_event = threading.Event()
        self._nixl_handshake_listener_t = threading.Thread(
            target=self._nixl_handshake_listener,
            args=(metadata, ready_event, self.rank),
            daemon=True,
            name="nixl_handshake_listener")
        self._nixl_handshake_listener_t.start()
        # Block until the listener socket is bound so peers handshaking
        # immediately afterwards cannot miss us.
        ready_event.wait()
    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
        """Register a remote NIXL agent and prepare per-block transfer
        descriptor lists for both the local (src) and remote (dst) side.

        Idempotent per engine_id: repeat calls for a known engine return
        immediately.
        """
        engine_id = nixl_agent_meta.engine_id
        if engine_id in self._remote_agents:
            return

        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
            nixl_agent_meta.agent_metadata)
        self.kv_caches_base_addr[
            engine_id] = nixl_agent_meta.kv_caches_base_addr

        # Create src descs and xfer side handles.
        # NOTE(review): rebuilt on every new remote agent even though the
        # src side does not depend on the remote — appears redundant but
        # harmless; confirm.
        blocks_data = []
        for base_addr in self.kv_caches_base_addr[self.engine_id]:
            for block_id in range(self.num_blocks):
                block_offset = block_id * self.block_len
                # (addr, len, device id)
                blocks_data.append(
                    (base_addr + block_offset, self.block_len, self.rank))
        logger.debug("Created %s blocks for src engine %s and rank %s",
                     len(blocks_data), self.engine_id, self.rank)

        # Register with NIXL.
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
            "NIXL_INIT_AGENT", descs)

        # Create dst descs and xfer side handles.
        self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
        blocks_data = []
        for base_addr in self.kv_caches_base_addr[engine_id]:
            for block_id in range(nixl_agent_meta.num_blocks):
                block_offset = block_id * self.block_len
                # (addr, len, device id)
                blocks_data.append(
                    (base_addr + block_offset, self.block_len, self.rank))
        logger.debug("Created %s blocks for dst engine %s and rank %s",
                     len(blocks_data), engine_id, self.rank)

        # Register with NIXL.
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        self.dst_xfer_side_handles[
            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                self._remote_agents[engine_id], descs)
    def get_finished(self) -> tuple[set[str], set[str]]:
        """
        Get requests that are done sending or recving.

        In TP>1 setup, each rank exchanges KVs with its counterpart
        ranks independently. get_finished() runs in each worker and
        creates the done_sending and done_recving sets that are sent to
        the scheduler via ModelRunnerOutput by Rank 0. To ensure
        transfers are done before adding to finished, Ranks 1 to N-1
        communicate to Rank 0 once their transaction is done + Rank 0
        returns finished sets to Scheduler only once all ranks are done.
        """
        done_sending = self._get_new_notifs()
        done_recving = self._pop_done_transfers(self._recving_transfers)
        if len(done_sending) > 0 or len(done_recving) > 0:
            logger.debug(
                "Rank %s, get_finished: %s requests done sending "
                "and %s requests done recving", self.rank, len(done_sending),
                len(done_recving))

        # Single-rank: no cross-rank aggregation needed.
        if self.world_size == 1:
            return done_sending, done_recving

        # Rank 0: get finished from all other ranks.
        if self.rank == 0:
            for req_id in done_sending:
                self._done_sending_count[req_id] += 1
            for req_id in done_recving:
                self._done_recving_count[req_id] += 1

            # Keep track of how many other ranks have finished.
            other_ranks_finished_ids: list[str] = []
            for i in range(1, self.world_size):
                other_ranks_finished_ids.extend(
                    self.tp_group.recv_object(src=i))
            for req_id in other_ranks_finished_ids:
                # Classify each remote-finished id as recv vs send based on
                # whether rank 0 has seen it as a recv transfer.
                if (req_id in self._done_recving_count
                        or req_id in self._recving_transfers):
                    self._done_recving_count[req_id] += 1
                else:
                    self._done_sending_count[req_id] += 1

            # Return ids that finished on all ranks to the scheduler.
            all_done_recving: set[str] = set()
            for req_id in list(self._done_recving_count.keys()):
                if self._done_recving_count[req_id] == self.world_size:
                    del self._done_recving_count[req_id]
                    all_done_recving.add(req_id)

            all_done_sending: set[str] = set()
            for req_id in list(self._done_sending_count.keys()):
                if self._done_sending_count[req_id] == self.world_size:
                    del self._done_sending_count[req_id]
                    all_done_sending.add(req_id)

            return all_done_sending, all_done_recving

        # Ranks 1 to N-1: send finished ids to Rank 0.
        else:
            finished_req_ids = list(done_recving.union(done_sending))
            self.tp_group.send_object(finished_req_ids, dst=0)

            # Unused as only Rank 0 results are sent to scheduler.
            return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
"""Get req_ids which got a remote xfer message."""
notified_req_ids: set[str] = set()
for req_ids in self.nixl_wrapper.get_new_notifs().values():
for req_id in req_ids:
assert req_id not in notified_req_ids
notified_req_ids.add(req_id.decode("utf-8"))
return notified_req_ids
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
transfers: dict of req_id -> list[running_xfer]
Returns:
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
running_reqs = []
for handle in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# TODO ptarasiewicz: why abort is throwing errors?
# self.nixl_wrapper.release_xfer_handle(handle)
continue
if xfer_state == "PROC":
running_reqs.append(handle)
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
if len(running_reqs) == 0:
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = running_reqs
return done_req_ids
    def start_load_kv(self, metadata: NixlConnectorMetadata):
        """
        Start loading by triggering non-blocking nixl_xfer.
        We check for these transfers to complete in each step().

        Args:
            metadata: per-request transfer descriptors built by the
                scheduler-side connector this step.
        """
        for req_id, meta in metadata.requests.items():
            logger.debug(
                "start_load_kv for request %s from remote engine %s. "
                "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
                meta.remote_engine_id, len(meta.local_block_ids),
                len(meta.remote_block_ids))
            self._read_blocks(
                request_id=req_id,
                dst_engine_id=meta.remote_engine_id,
                local_block_ids=meta.local_block_ids,
                remote_block_ids=meta.remote_block_ids,
                remote_host=meta.remote_host,
                remote_port=meta.remote_port,
            )
    def _read_blocks(
        self,
        local_block_ids: list[int],
        remote_block_ids: list[int],
        remote_host: str,
        remote_port: int,
        dst_engine_id: str,
        request_id: str,
    ):
        """Issue one async NIXL READ pulling remote blocks into local ones,
        handshaking with the remote first if we have not seen it before.
        The remote is always notified (even on a full local cache hit) so
        the producer can release the request's blocks."""
        # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
        if dst_engine_id not in self._remote_agents:
            self._nixl_handshake(remote_host, remote_port)

        # NOTE(rob): having the staging blocks be on the READER side is
        # not going to work well (since we will have to call rearrange tensors).
        # after we detect the txn is complete (which means we cannot make the
        # read trxn async easily). If we want to make "READ" happen cleanly,
        # then we will need to have the staging blocks on the remote side.

        # NOTE(rob): according to nvidia the staging blocks are used to
        # saturate IB with heterogeneous TP sizes. We should remove the staging
        # blocks until we are ready.

        # Full prefix cache hit: do not need to read remote blocks,
        # just notify P worker that we have the blocks we need.
        num_local_blocks = len(local_block_ids)
        if num_local_blocks == 0:
            self.nixl_wrapper.send_notif(dst_engine_id,
                                         notif_msg=request_id.encode("utf-8"))
            return

        # Partial prefix cache hit: just read uncomputed blocks.
        num_remote_blocks = len(remote_block_ids)
        assert num_local_blocks <= num_remote_blocks
        if num_local_blocks < num_remote_blocks:
            # Keep the tail: the first blocks were already computed locally.
            remote_block_ids = remote_block_ids[-num_local_blocks:]

        # Get side handles.
        local_xfer_side_handle = self.src_xfer_side_handle
        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

        # Get descs ids.
        remote_block_descs_ids = self._get_block_descs_ids(
            dst_engine_id, remote_block_ids)
        local_block_descs_ids = self._get_block_descs_ids(
            self.engine_id, local_block_ids)
        assert len(local_block_descs_ids) == len(remote_block_descs_ids)

        # Prepare transfer with Nixl. The notif_msg is delivered to the
        # remote on completion so it can free the request's blocks.
        handle = self.nixl_wrapper.make_prepped_xfer(
            "READ",
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg=request_id.encode("utf-8"),
        )

        # Begin async xfer.
        self.nixl_wrapper.transfer(handle)

        # Use handle to check completion in future step().
        self._recving_transfers[request_id].append(handle)
def _get_block_descs_ids(self, engine_id: str,
block_ids: list[int]) -> list[int]:
"""Get the descs ids for a set of block ids."""
# range(1) for MLA, range(2) otherwise.
region_ids = range(self.num_regions)
num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block.
descs_ids: list[int] = []
for reg_id in region_ids:
for block_id in block_ids:
descs_ids.append(reg_id * num_blocks + block_id)
return descs_ids
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket"""
    # Keep a handle outside the try so cleanup can run even if
    # socket creation itself raises.
    ctx: Optional[zmq.Context] = None
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]
        if socket_type == zmq.ROUTER:
            # Server side: bind and wait for peers.
            sock = ctx.socket(zmq.ROUTER)
            sock.bind(addr)
        elif socket_type == zmq.REQ:
            # Client side: connect out to the remote endpoint.
            sock = ctx.socket(zmq.REQ)
            sock.connect(addr)
        else:
            raise ValueError(f"Unexpected socket type: {socket_type}")
        yield sock
    finally:
        # Destroy the context (closing its sockets) without lingering
        # on any unsent messages.
        if ctx is not None:
            ctx.destroy(linger=0)

View File

@@ -17,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -132,8 +133,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
@@ -225,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> int:
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
@@ -239,7 +239,6 @@ class SharedStorageConnector(KVConnectorBase_V1):
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
@@ -248,7 +247,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0
return 0, False
logger.info("External Cache Hit!")
@@ -257,9 +256,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.

View File

@@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
# doc: end-chat-completion-extra-params
@@ -540,7 +543,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
logit_bias=self.logit_bias,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
@@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
# doc: end-completion-extra-params
# Default sampling parameters for completion requests
@@ -973,7 +982,9 @@ class CompletionRequest(OpenAIBaseModel):
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
@model_validator(mode="before")
@classmethod
@@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel):
model: str
choices: list[CompletionResponseChoice]
usage: UsageInfo
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
class CompletionResponseStreamChoice(OpenAIBaseModel):
@@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel):
choices: list[ChatCompletionResponseChoice]
usage: UsageInfo
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
class DeltaMessage(OpenAIBaseModel):

View File

@@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing):
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
)
return response

View File

@@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing):
model=model_name,
choices=choices,
usage=usage,
)
kv_transfer_params=final_res_batch[0].kv_transfer_params)
def _create_completion_logprobs(
self,

View File

@@ -112,6 +112,8 @@ if TYPE_CHECKING:
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
def get_default_cache_root():
@@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# insecure method and it is needed for some reason.
"VLLM_ALLOW_INSECURE_SERIALIZATION":
lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
# IP address used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_HOST":
lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),
# Port used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_PORT":
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
}
# end-env-vars-definition

View File

@@ -11,10 +11,6 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger
if TYPE_CHECKING:
@@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
# KVConnector: trigger (possibly async) load before forward.
# Each attn layer will block until the reading is complete.
trigger_kv_transfer = (attn_metadata is not None
and has_kv_transfer_group()
and is_v1_kv_transfer_group())
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context)
try:
yield
finally:
@@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any,
"(batchsize, count, median_time(ms)): %s"),
forward_stats)
# KVConnector: each attn layer triggers (possibly async) save.
# Ensure all those operations complete before forward() is done.
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.wait_for_save()
_forward_context = prev_context

View File

@@ -4,7 +4,7 @@ import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Generic, Optional, Union
from typing import Any, Generic, Optional, Union
import torch
from typing_extensions import TypeVar, deprecated
@@ -103,6 +103,7 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
kv_transfer_params: The params for remote K/V transfer.
"""
def __init__(
@@ -120,6 +121,7 @@ class RequestOutput:
num_cached_tokens: Optional[int] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -133,11 +135,13 @@ class RequestOutput:
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params
for next_completion in next_output.outputs:
for i, completion in enumerate(self.outputs):

View File

@@ -36,6 +36,12 @@ class KVCacheBlocks:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
return [
block.block_id for block in self.blocks if block.block_hash is None
]
class KVCacheManager:
@@ -116,6 +122,12 @@ class KVCacheManager:
- The number of computed tokens.
"""
# Request already has blocks from async load via KVConnector.
num_existing_blocks = len(
self.single_type_manager.req_to_blocks[request.request_id])
if num_existing_blocks > 0:
return KVCacheBlocks.create_empty(), request.num_computed_tokens
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
@@ -173,6 +185,7 @@ class KVCacheManager:
num_new_tokens: int,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.
@@ -186,6 +199,9 @@ class KVCacheManager:
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
Blocks layout:
```
@@ -255,7 +271,9 @@ class KVCacheManager:
new_blocks = self.single_type_manager.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
if not self.enable_caching:
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return KVCacheBlocks(new_blocks)
# Speculated tokens might be rejected in the future, so we do
@@ -350,3 +368,16 @@ class KVCacheManager:
A list of KV cache events.
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return [
block.block_id
for block in self.single_type_manager.req_to_blocks[request_id]
]
def get_num_blocks(self, request_id: str):
"""Get the number of blocks."""
assert request_id in self.single_type_manager.req_to_blocks
return len(self.single_type_manager.req_to_blocks[request_id])

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
@@ -137,3 +138,6 @@ class SchedulerInterface(ABC):
def shutdown(self) -> None:
"""Shutdown the scheduler."""
raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
return None

View File

@@ -5,13 +5,15 @@ from __future__ import annotations
import time
from collections import defaultdict, deque
from collections.abc import Iterable
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole,
KVTransferParams)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface):
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# P/D: requests in process of recving KV transfers
self.finished_recving_kv_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> deque of CachedRequestData
@@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0]
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
@@ -330,49 +345,55 @@ class Scheduler(SchedulerInterface):
continue
# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
new_computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
num_external_tokens = (
0 if self.connector is None else
num_external_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens += num_external_tokens
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_tokens,
computed_blocks,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
if new_blocks is None:
# The request cannot be scheduled.
@@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_tokens,
)
self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
@@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids()
self.kv_cache_manager.get_block_ids(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
@@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface):
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
@@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface):
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
kv_transfer_params = self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
@@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
if new_token_ids or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
@@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface):
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
@@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface):
if not stopped:
new_running.append(request)
# P/D: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
@@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request)
def _free_request(self, request: Request) -> None:
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
if not delay_free_blocks:
self._free_blocks(request)
return kv_xfer_params
def _free_blocks(self, request: Request):
assert request.is_finished()
assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
@@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface):
def shutdown(self) -> None:
if self.kv_event_publisher:
self.kv_event_publisher.shutdown()
########################################################################
# P/D Related Methods
########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
return self.connector
def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
"""Invoke the KV connector request_finished() method if applicable."""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
P/D: check if the request_id is finished_recving.
The finished_recving_kv_req_ids list is populated
on the previous steps()'s update_from_output based
on the worker side connector.
When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
"""
if request.request_id not in self.finished_recving_kv_req_ids:
return False
# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.single_type_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,
)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def _update_from_kv_xfer_finished(self,
model_runner_output: ModelRunnerOutput):
"""
P/D: update the scheduler state based on the output.
The Worker side connectors add finished_recving and
finished_sending reqs to the output.
* if finished_sending: free the blocks
* if finished_recving: add to state so we can
schedule the request during the next step.
"""
# P/D: update recv and send status from last step.
for req_id in (model_runner_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
for req_id in (model_runner_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])

View File

@@ -105,6 +105,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
@property
def finished(self) -> bool:

View File

@@ -182,6 +182,15 @@ class EngineCore:
# Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None:
if (kv_connector := self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector.
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req)
def abort_requests(self, request_ids: list[str]):

View File

@@ -3,7 +3,7 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
@@ -146,6 +146,7 @@ class RequestState:
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Optional[RequestOutput]:
finished = finish_reason is not None
@@ -167,13 +168,15 @@ class RequestState:
if not outputs:
return None
return self._new_request_output(request_id, outputs, finished)
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA:
@@ -189,6 +192,7 @@ class RequestState:
prompt_logprobs=prompt_logprobs,
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
)
def _new_completion_output(
@@ -335,6 +339,7 @@ class OutputProcessor:
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.is_prefilling = False
@@ -350,7 +355,8 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason):
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)

View File

@@ -100,12 +100,16 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
finished_sending=None,
finished_recving=None)

View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
@@ -61,6 +62,15 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed).
raw_params = (None if sampling_params.extra_args is None
else sampling_params.extra_args.get(
"kv_transfer_params", None))
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes:
@@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import gc
import time
import weakref
@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
@@ -1065,15 +1067,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
scheduler_output.kv_connector_metadata)
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model(
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = output
hidden_states, aux_hidden_states = model_output
else:
hidden_states = output
hidden_states = model_output
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec,