[P/D] NIXL Integration (#17751)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Brent Salisbury <bsalisbu@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Brent Salisbury <bsalisbu@redhat.com>
@@ -8,6 +8,7 @@ import inspect
import json
import re
import textwrap
import uuid
import warnings
from collections import Counter
from contextlib import contextmanager
@@ -3438,6 +3439,9 @@ class KVTransferConfig:
    """The KV connector for vLLM to transmit KV caches between vLLM instances.
    """

    engine_id: str = str(uuid.uuid4())
    """The engine id for KV transfers."""

    kv_buffer_device: Optional[str] = "cuda"
    """The device used by the kv connector to buffer the KV cache.
    Currently, only 'cuda' is supported."""
@@ -3448,7 +3452,7 @@ class KVTransferConfig:

    kv_role: Optional[KVRole] = None
    """Whether this vLLM instance produces, consumes KV cache, or both. Choices
    are 'kv_producer', 'kv_consumer', and 'both'."""
    are 'kv_producer', 'kv_consumer', and 'kv_both'."""

    kv_rank: Optional[int] = None
    """The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector(
    "LMCacheConnectorV1",
    "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
    "LMCacheConnectorV1")

KVConnectorFactory.register_connector(
    "NixlConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
    "NixlConnector")
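For context (not part of this change), a minimal sketch of how the connector registered above might be selected through the KVTransferConfig fields shown earlier in this diff; how the config object reaches the engine is assumed and not shown here:

# Illustrative only: choosing the NIXL connector via KVTransferConfig.
# Field names (kv_connector, kv_role) come from the config hunk above;
# the surrounding engine wiring is an assumption.
from vllm.config import KVTransferConfig

kv_config = KVTransferConfig(
    kv_connector="NixlConnector",  # name registered with KVConnectorFactory
    kv_role="kv_both",             # this instance may both send and receive KV
)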
@@ -1,8 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorRole)
    KVConnectorBase_V1, KVConnectorRole, KVTransferParams)

__all__ = [
    "KVConnectorRole",
    "KVConnectorBase_V1",
]
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
@@ -23,7 +23,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

import torch

@@ -34,6 +34,7 @@ if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum):
    WORKER = 1


class KVTransferParams:
    """
    Abstract KVTransferParams used to send KVTransfer
    parameters between instances of vLLM.

    Specific instances of KVConnector customize this
    method for serializing / deserializing msgs sent
    via the HTTP protocol.
    """

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str,
                                    Any]]) -> Optional["KVTransferParams"]:
        return None


@dataclass
class KVConnectorMetadata:
    """
    Abstract Metadata used to communicate between the
    Scheduler KVConnector and Worker KVConnector.
    """
    pass


class KVConnectorBase_V1(ABC):
    _KVTransferParams = KVTransferParams

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        logger.warning(
@@ -66,6 +89,10 @@ class KVConnectorBase_V1(ABC):
    def role(self) -> KVConnectorRole:
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.
@@ -97,9 +124,15 @@ class KVConnectorBase_V1(ABC):
        """
        return self._connector_metadata

    # ==============================
    # Worker-side methods
    # ==============================
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args: kv_caches:
            dictionary of layer names, kv cache
        """
        return

    @abstractmethod
    def start_load_kv(self, forward_context: "ForwardContext",
@@ -162,15 +195,37 @@ class KVConnectorBase_V1(ABC):
        """
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfers
            (recving, sending).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        return None, None

    # ==============================
    # Scheduler-side methods
    # ==============================

    def set_kv_transfer_params(self, request: "Request"):
        """Parse raw KV Transfer params."""
        assert request.kv_transfer_params is None
        kv_transfer_params = self._KVTransferParams.from_raw_dict(
            request.raw_kv_transfer_params)
        request.kv_transfer_params = kv_transfer_params

    @abstractmethod
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> int:
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.
@@ -181,13 +236,16 @@ class KVConnectorBase_V1(ABC):
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        pass

    @abstractmethod
    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
@@ -207,3 +265,20 @@ class KVConnectorBase_V1(ABC):
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        pass

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return False, None
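To make the updated scheduler-side contract concrete, here is a minimal, hypothetical connector sketch against the interface above (not part of this change). Method names and the (num_tokens, load_async) return shape are taken from this diff; the no-op behavior, the class name, and the exact set of abstract methods left out of the excerpt are assumptions:

# Hypothetical no-op connector illustrating the new tuple-returning
# get_num_new_matched_tokens() and the request_finished() contract.
from typing import Any, Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata)


class NoOpConnector(KVConnectorBase_V1):

    def get_num_new_matched_tokens(
            self, request, num_computed_tokens: int) -> tuple[int, bool]:
        # No external KV is available, and nothing loads asynchronously.
        return 0, False

    def update_state_after_alloc(self, request, blocks,
                                 num_external_tokens: int) -> None:
        pass

    def build_connector_meta(self, scheduler_output) -> KVConnectorMetadata:
        return KVConnectorMetadata()

    def request_finished(
            self, request,
            block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
        # Free blocks immediately; return no params to the client.
        return False, None

    def start_load_kv(self, forward_context, **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata,
                      **kwargs) -> None:
        pass

    def wait_for_save(self) -> None:
        pass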
@@ -13,6 +13,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -92,7 +93,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> int:
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.
@@ -107,9 +108,10 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
            external KV cache beyond what is already computed.
        """
        return self._lmcache_engine.get_num_new_matched_tokens(
            request, num_computed_tokens)
            request, num_computed_tokens), False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (new file, 805 lines)
@@ -0,0 +1,805 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import math
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator

import msgspec
import torch
import zmq
from typing_extensions import Optional

from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
    get_tp_group)
from vllm.logger import init_logger
from vllm.utils import round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

GET_META_MSG = b"get_meta_msg"

logger = init_logger(__name__)

# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
    from nixl._api import nixl_agent as NixlWrapper
    logger.info("NIXL is available")
except ImportError:
    logger.warning("NIXL is not available")
    NixlWrapper = None


@dataclass
class NixlKVTransferParams(KVTransferParams):

    def __init__(
        self,
        do_remote_prefill: bool,
        do_remote_decode: bool,
        remote_block_ids: Optional[list[int]] = None,
        remote_host: Optional[str] = None,
        remote_port: Optional[int] = None,
        remote_engine_id: Optional[str] = None,
    ):
        self.do_remote_prefill = do_remote_prefill
        self.do_remote_decode = do_remote_decode
        self.remote_block_ids = remote_block_ids
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.remote_engine_id = remote_engine_id

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str,
                                    Any]]) -> Optional["NixlKVTransferParams"]:

        # If no raw transfer params passed, return None.
        if raw_dict is None:
            return None

        # Validate the request is formatted properly.
        if (("do_remote_prefill" not in raw_dict)
                or ("do_remote_decode" not in raw_dict)
                or ("remote_block_ids" not in raw_dict)
                or ("remote_host" not in raw_dict)
                or ("remote_port" not in raw_dict)
                or ("remote_engine_id" not in raw_dict)):
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer", raw_dict)
            return None

        return NixlKVTransferParams(
            do_remote_prefill=raw_dict["do_remote_prefill"],
            do_remote_decode=raw_dict["do_remote_decode"],
            remote_block_ids=raw_dict["remote_block_ids"],
            remote_host=raw_dict["remote_host"],
            remote_port=raw_dict["remote_port"],
            remote_engine_id=raw_dict["remote_engine_id"],
        )


class NixlAgentMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):
    engine_id: str
    agent_metadata: bytes
    kv_caches_base_addr: list[int]
    num_blocks: int


@dataclass
class ReqMeta:
    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_engine_id: str


class NixlConnectorMetadata(KVConnectorMetadata):

    def __init__(self):
        self.requests: dict[str, ReqMeta] = {}

    def add_new_req(
        self,
        request_id: str,
        local_block_ids: list[int],
        kv_transfer_params: NixlKVTransferParams,
    ):
        assert request_id not in self.requests
        assert kv_transfer_params.remote_block_ids is not None
        assert kv_transfer_params.remote_engine_id is not None
        assert kv_transfer_params.remote_host is not None
        assert kv_transfer_params.remote_port is not None

        self.requests[request_id] = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params.remote_block_ids,
            remote_engine_id=kv_transfer_params.remote_engine_id,
            remote_host=kv_transfer_params.remote_host,
            remote_port=kv_transfer_params.remote_port,
        )


class NixlConnector(KVConnectorBase_V1):
    _KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: Optional[NixlConnectorScheduler] = \
                NixlConnectorScheduler(vllm_config, str(self.engine_id))
            self.connector_worker: Optional[NixlConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = NixlConnectorWorker(str(self.engine_id))

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """NixlConnector does not do layerwise saving."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """NixlConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """NixlConnector does not save explicitly."""
        pass


class NixlConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        logger.info("Initializing NIXL Scheduler %s", engine_id)

        # Requests that need to start recv.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """

        # No KVTransfer for this request.
        if request.kv_transfer_params is None:
            return 0, False
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        # Remote prefill: get all prompt blocks from remote.
        if request.kv_transfer_params.do_remote_prefill:
            assert num_computed_tokens % self.block_size == 0
            rounded_num_prompt_tokens = round_down(
                len(request.prompt_token_ids), self.block_size)
            count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
            return count, count > 0

        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        if request.kv_transfer_params is None:
            return

        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
        if request.kv_transfer_params.do_remote_prefill:
            # NOTE(rob): if prompt < block_size, no remote blocks
            # since the remote only sends fully computed blocks, so
            # skip recving for this request. num_external_tokens
            # should be 0 if there are no remote blocks.
            if request.kv_transfer_params.remote_block_ids:
                # Get unhashed blocks to pull from remote.
                self._reqs_need_recv[request.request_id] = (
                    request, blocks.get_unhashed_block_ids())
            else:
                assert num_external_tokens == 0
            # Only trigger 1 KV transfer per request.
            request.kv_transfer_params.do_remote_prefill = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        meta = NixlConnectorMetadata()

        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.
        """

        if request.kv_transfer_params is None:
            return False, None
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        if ((not request.kv_transfer_params.do_remote_decode)
                or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
            return False, None

        # Get computed blocks.
        all_full = request.num_computed_tokens % self.block_size == 0
        computed_block_ids = (block_ids if all_full else block_ids[:-1])

        # If prompt < block_size, no xfer so free blocks immediately.
        delay_free_blocks = len(computed_block_ids) > 0

        return delay_free_blocks, NixlKVTransferParams(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=computed_block_ids,
            remote_engine_id=self.engine_id,
            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
        ).__dict__


class NixlConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, engine_id: str):
        if NixlWrapper is None:
            logger.error("NIXL is not available")
            raise RuntimeError("NIXL is not available")
        logger.info("Initializing NIXL wrapper")
        logger.info("Initializing NIXL worker %s", engine_id)

        # Agent.
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
        # Map of engine_id -> agent_name.
        self._remote_agents: dict[str, str] = {}

        # Metadata.
        self.engine_id = engine_id
        self.rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()
        self.tp_group = get_tp_group()

        # KV Caches and nixl tracking data.
        self.kv_caches: dict[str, torch.Tensor] = {}

        # Map of engine_id -> kv_caches_base_addr
        self.kv_caches_base_addr: dict[str, list[int]] = {}

        # Number of NIXL regions. Currently one region per cache
        # (so 1 per layer for MLA, otherwise 2 per layer)
        self.num_regions = 0

        # nixl_prepped_dlist_handle (int).
        self.src_xfer_side_handle: int = 0
        # Map of engine_id -> nixl_prepped_dlist_handle (int)].
        self.dst_xfer_side_handles: dict[str, int] = {}

        # Map of engine_id -> num_blocks.
        self.dst_num_blocks: dict[str, int] = {}
        self._registered_descs: list[Any] = []

        # In progress transfers.
        # [req_id -> list[handle]]
        self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
            list[Any])

        # Complete transfer tracker. Used by the rank 0 to track finished
        # transactions on ranks 1 to N-1.
        # [req_id -> count]
        self._done_recving_count: defaultdict[str,
                                              int] = defaultdict(lambda: 0)
        self._done_sending_count: defaultdict[str,
                                              int] = defaultdict(lambda: 0)

        # Background thread for establishing new connections.
        self._nixl_handshake_listener_t: Optional[threading.Thread] = None

    @staticmethod
    def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                 ready_event: threading.Event, rank: int):
        """Background thread for getting new NIXL handshakes."""
        # NOTE(rob): this is a simple implementation. We will move
        # to a better approach like an ETCD server in the future.

        # NOTE(rob): to support heterogeneous TP, we will have to
        # move this into the scheduler rather than worker, since
        # each rank needs the metadata of all other ranks (whereas
        # in this setup, each rank only gets one other rank's meta).

        encoder = msgspec.msgpack.Encoder()
        encoded_data = encoder.encode(metadata)
        size_in_bytes = len(encoded_data)
        logger.debug("Size of encoded NixlAgentMetadata: %s bytes",
                     str(size_in_bytes))

        # Listen for new requests for metadata.
        host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        # NOTE(rob): we need each rank to have a unique port. This
        # hack keeps us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.
        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
        path = f"tcp://{host}:{port}"
        logger.debug("Starting listening on path: %s", path)
        with zmq_ctx(zmq.ROUTER, path) as sock:
            ready_event.set()
            while True:
                identity, _, msg = sock.recv_multipart()
                if msg != GET_META_MSG:
                    logger.warning(
                        "Connection listener got unexpected message %s", msg)
                sock.send_multipart((identity, b"", encoded_data))

    def _nixl_handshake(self, host: str, port: int):
        """Do a NIXL handshake with a remote instance."""

        start_time = time.perf_counter()
        # NOTE(rob): we need each rank to have a unique port. This is
        # a hack to keep us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.
        path = f"tcp://{host}:{port + self.rank}"
        logger.debug("Querying metadata on path: %s", path)
        with zmq_ctx(zmq.REQ, path) as sock:
            # Send query for the request.
            sock.send(GET_META_MSG)
            metadata_bytes = sock.recv()
            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
            metadata = decoder.decode(metadata_bytes)
            got_metadata_time = time.perf_counter()

            # Register Remote agent.
            self.add_remote_agent(metadata)
            setup_agent_time = time.perf_counter()

            logger.debug("NIXL handshake: get metadata took: %s",
                         got_metadata_time - start_time)
            logger.debug("NIXL handshake: add agent took: %s",
                         setup_agent_time - got_metadata_time)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl."""

        _, first_kv_cache = next(iter(kv_caches.items()))
        kv_elem_size = first_kv_cache.element_size()

        # TODO(tms): Find a more robust way to detect and handle MLA
        use_mla = len(first_kv_cache.shape) == 3
        if use_mla:
            # MLA case.
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 2  # [block_size, latent_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
        else:
            # [2 (k and v), num_blocks, ...]
            self.num_blocks = first_kv_cache.shape[1]
            block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]

        # TODO(tms): self.block_len needs to be per-layer for sliding window,
        # hybrid attn, etc
        self.block_len = kv_elem_size * math.prod(block_shape)

        logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
                     first_kv_cache.shape)
        logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
                     block_shape)
        logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
        self.dst_num_blocks[self.engine_id] = self.num_blocks
        self.kv_caches = kv_caches
        kv_caches_base_addr = []
        caches_data = []

        # Note(tms): I modified this from the original region setup code.
        # K and V are now in different regions. Advantage is that we can
        # elegantly support MLA and any cases where the K and V tensors
        # are non-contiguous (it's not locally guaranteed that they will be)
        # Disadvantage is that the encoded NixlAgentMetadata is now larger
        # (roughly 8KB vs 5KB).
        for cache_or_caches in kv_caches.values():
            # Normalize to always be a list of caches
            cache_list = [cache_or_caches] if use_mla else cache_or_caches
            for cache in cache_list:
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len
                caches_data.append((base_addr, region_len, self.rank, ""))
                kv_caches_base_addr.append(base_addr)
        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
        self.num_regions = len(caches_data)

        descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs)
        logger.debug("Done registering descs")

        self._registered_descs.append(descs)

        # After KV Caches registered, listen for new connections.
        metadata = NixlAgentMetadata(
            engine_id=self.engine_id,
            agent_metadata=self.nixl_wrapper.get_agent_metadata(),
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
            num_blocks=self.num_blocks,
        )
        ready_event = threading.Event()
        self._nixl_handshake_listener_t = threading.Thread(
            target=self._nixl_handshake_listener,
            args=(metadata, ready_event, self.rank),
            daemon=True,
            name="nixl_handshake_listener")
        self._nixl_handshake_listener_t.start()
        ready_event.wait()

    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
        engine_id = nixl_agent_meta.engine_id
        if engine_id in self._remote_agents:
            return

        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
            nixl_agent_meta.agent_metadata)
        self.kv_caches_base_addr[
            engine_id] = nixl_agent_meta.kv_caches_base_addr

        # Create src descs and xfer side handles.
        blocks_data = []
        for base_addr in self.kv_caches_base_addr[self.engine_id]:
            for block_id in range(self.num_blocks):
                block_offset = block_id * self.block_len
                # (addr, len, device id)
                blocks_data.append(
                    (base_addr + block_offset, self.block_len, self.rank))
        logger.debug("Created %s blocks for src engine %s and rank %s",
                     len(blocks_data), self.engine_id, self.rank)

        # Register with NIXL.
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
            "NIXL_INIT_AGENT", descs)

        # Create dst descs and xfer side handles.
        self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
        blocks_data = []
        for base_addr in self.kv_caches_base_addr[engine_id]:
            for block_id in range(nixl_agent_meta.num_blocks):
                block_offset = block_id * self.block_len
                # (addr, len, device id)
                blocks_data.append(
                    (base_addr + block_offset, self.block_len, self.rank))
        logger.debug("Created %s blocks for dst engine %s and rank %s",
                     len(blocks_data), engine_id, self.rank)

        # Register with NIXL.
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        self.dst_xfer_side_handles[
            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                self._remote_agents[engine_id], descs)

    def get_finished(self) -> tuple[set[str], set[str]]:
        """
        Get requests that are done sending or recving.

        In TP>1 setup, each rank exchanges KVs with its counterpart
        ranks independently. get_finished() runs in each worker and creates
        the done_sending and done_recving sets that are sent to the
        scheduler via ModelRunnerOutput by Rank 0. To ensure transactions
        are done before adding to finished, Ranks 1 to N-1 communicate
        to Rank 0 once their transaction is done + Rank 0 returns
        finished sets to Scheduler only once all ranks are done.
        """
        done_sending = self._get_new_notifs()
        done_recving = self._pop_done_transfers(self._recving_transfers)
        if len(done_sending) > 0 or len(done_recving) > 0:
            logger.debug(
                "Rank %s, get_finished: %s requests done sending "
                "and %s requests done recving", self.rank, len(done_sending),
                len(done_recving))

        if self.world_size == 1:
            return done_sending, done_recving

        # Rank 0: get finished from all other ranks.
        if self.rank == 0:
            for req_id in done_sending:
                self._done_sending_count[req_id] += 1
            for req_id in done_recving:
                self._done_recving_count[req_id] += 1

            # Keep track of how many other ranks have finished.
            other_ranks_finished_ids: list[str] = []
            for i in range(1, self.world_size):
                other_ranks_finished_ids.extend(
                    self.tp_group.recv_object(src=i))
            for req_id in other_ranks_finished_ids:
                if (req_id in self._done_recving_count
                        or req_id in self._recving_transfers):
                    self._done_recving_count[req_id] += 1
                else:
                    self._done_sending_count[req_id] += 1

            # Return ids that finished on all ranks to the scheduler.
            all_done_recving: set[str] = set()
            for req_id in list(self._done_recving_count.keys()):
                if self._done_recving_count[req_id] == self.world_size:
                    del self._done_recving_count[req_id]
                    all_done_recving.add(req_id)

            all_done_sending: set[str] = set()
            for req_id in list(self._done_sending_count.keys()):
                if self._done_sending_count[req_id] == self.world_size:
                    del self._done_sending_count[req_id]
                    all_done_sending.add(req_id)

            return all_done_sending, all_done_recving

        # Ranks 1 to N-1: send finished ids to Rank 0.
        else:
            finished_req_ids = list(done_recving.union(done_sending))
            self.tp_group.send_object(finished_req_ids, dst=0)

            # Unused as only Rank 0 results are sent to scheduler.
            return done_sending, done_recving

    def _get_new_notifs(self) -> set[str]:
        """Get req_ids which got a remote xfer message."""

        notified_req_ids: set[str] = set()
        for req_ids in self.nixl_wrapper.get_new_notifs().values():
            for req_id in req_ids:
                assert req_id not in notified_req_ids
                notified_req_ids.add(req_id.decode("utf-8"))
        return notified_req_ids

    def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
        """
        Pop completed xfers by checking for DONE state.
        Args:
            transfers: dict of req_id -> list[running_xfer]
        Returns:
            set of req_ids that have all done xfers
        """
        done_req_ids: set[str] = set()
        for req_id, handles in list(transfers.items()):
            running_reqs = []
            for handle in handles:
                xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                if xfer_state == "DONE":
                    # TODO ptarasiewicz: why abort is throwing errors?
                    # self.nixl_wrapper.release_xfer_handle(handle)
                    continue
                if xfer_state == "PROC":
                    running_reqs.append(handle)
                else:
                    raise RuntimeError("Transfer failed with state %s",
                                       xfer_state)
            if len(running_reqs) == 0:
                done_req_ids.add(req_id)
                del transfers[req_id]
            else:
                transfers[req_id] = running_reqs
        return done_req_ids

    def start_load_kv(self, metadata: NixlConnectorMetadata):
        """
        Start loading by triggering non-blocking nixl_xfer.
        We check for these transfers to complete in each step().
        """
        for req_id, meta in metadata.requests.items():
            logger.debug(
                "start_load_kv for request %s from remote engine %s. "
                "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
                meta.remote_engine_id, len(meta.local_block_ids),
                len(meta.remote_block_ids))
            self._read_blocks(
                request_id=req_id,
                dst_engine_id=meta.remote_engine_id,
                local_block_ids=meta.local_block_ids,
                remote_block_ids=meta.remote_block_ids,
                remote_host=meta.remote_host,
                remote_port=meta.remote_port,
            )

    def _read_blocks(
        self,
        local_block_ids: list[int],
        remote_block_ids: list[int],
        remote_host: str,
        remote_port: int,
        dst_engine_id: str,
        request_id: str,
    ):
        # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
        if dst_engine_id not in self._remote_agents:
            self._nixl_handshake(remote_host, remote_port)

        # NOTE(rob): having the staging blocks be on the READER side is
        # not going to work well (since we will have to call rearrange tensors).
        # after we detect the txn is complete (which means we cannot make the
        # read trxn async easily). If we want to make "READ" happen cleanly,
        # then we will need to have the staging blocks on the remote side.

        # NOTE(rob): according to nvidia the staging blocks are used to
        # saturate IB with heterogeneous TP sizes. We should remove the staging
        # blocks until we are ready.

        # Full prefix cache hit: do not need to read remote blocks,
        # just notify P worker that we have the blocks we need.
        num_local_blocks = len(local_block_ids)
        if num_local_blocks == 0:
            self.nixl_wrapper.send_notif(dst_engine_id,
                                         notif_msg=request_id.encode("utf-8"))
            return

        # Partial prefix cache hit: just read uncomputed blocks.
        num_remote_blocks = len(remote_block_ids)
        assert num_local_blocks <= num_remote_blocks
        if num_local_blocks < num_remote_blocks:
            remote_block_ids = remote_block_ids[-num_local_blocks:]

        # Get side handles.
        local_xfer_side_handle = self.src_xfer_side_handle
        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

        # Get descs ids.
        remote_block_descs_ids = self._get_block_descs_ids(
            dst_engine_id, remote_block_ids)
        local_block_descs_ids = self._get_block_descs_ids(
            self.engine_id, local_block_ids)
        assert len(local_block_descs_ids) == len(remote_block_descs_ids)

        # Prepare transfer with Nixl.
        handle = self.nixl_wrapper.make_prepped_xfer(
            "READ",
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg=request_id.encode("utf-8"),
        )

        # Begin async xfer.
        self.nixl_wrapper.transfer(handle)

        # Use handle to check completion in future step().
        self._recving_transfers[request_id].append(handle)

    def _get_block_descs_ids(self, engine_id: str,
                             block_ids: list[int]) -> list[int]:
        """Get the descs ids for a set of block ids."""

        # range(1) for MLA, range(2) otherwise.
        region_ids = range(self.num_regions)
        num_blocks = self.dst_num_blocks[engine_id]

        # Compute the desc ids for each block.
        descs_ids: list[int] = []
        for reg_id in region_ids:
            for block_id in block_ids:
                descs_ids.append(reg_id * num_blocks + block_id)
        return descs_ids


@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket"""

    ctx: Optional[zmq.Context] = None
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]

        if socket_type == zmq.ROUTER:
            socket = ctx.socket(zmq.ROUTER)
            socket.bind(addr)
        elif socket_type == zmq.REQ:
            socket = ctx.socket(zmq.REQ)
            socket.connect(addr)
        else:
            raise ValueError(f"Unexpected socket type: {socket_type}")

        yield socket
    finally:
        if ctx is not None:
            ctx.destroy(linger=0)
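For illustration (not part of this change), here is the shape of the raw kv_transfer_params dict that NixlKVTransferParams.from_raw_dict accepts, and roughly how it travels from a prefill (P) instance to a decode (D) instance. The values and the client/proxy wiring are made up; only the key names and the producer/consumer roles come from the code above:

# 1) The client asks the prefill instance to hand off to a remote decode:
prefill_request_params = {
    "do_remote_decode": True,
    "do_remote_prefill": False,
    "remote_block_ids": None,
    "remote_host": None,
    "remote_port": None,
    "remote_engine_id": None,
}

# 2) NixlConnectorScheduler.request_finished() on the prefill side returns
#    params like these, which the engine attaches to the response:
prefill_response_params = {
    "do_remote_prefill": True,
    "do_remote_decode": False,
    "remote_block_ids": [0, 1, 2, 3],   # computed blocks held by P (example)
    "remote_engine_id": "engine-p-0",   # made-up id
    "remote_host": "10.0.0.1",          # VLLM_NIXL_SIDE_CHANNEL_HOST on P
    "remote_port": 5557,                # VLLM_NIXL_SIDE_CHANNEL_PORT on P
}

# 3) The client forwards them unchanged as kv_transfer_params on the decode
#    request, so the decode worker pulls those blocks via a NIXL READ.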
@@ -17,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -132,8 +133,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)

            # Get the metadata
            metadata: KVConnectorMetadata = \
                self._get_connector_metadata()
            metadata: KVConnectorMetadata = self._get_connector_metadata()
            assert isinstance(metadata, SharedStorageConnectorMetadata)

            if metadata is None:
@@ -225,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> int:
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.
@@ -239,7 +239,6 @@ class SharedStorageConnector(KVConnectorBase_V1):
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """

        # NOTE: in this debug implementation, we assume that the prompt is
        # cached_prompt + newly_generated_single_token
        # Therefore, we use prompt_token_ids[:-1] to determine the folder name
@@ -248,7 +247,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
        # with the block granularity. And it expects the returned blocks and
        # num_computed_tokens to also be aligned with the block granularity.
        if not self._found_match_for_request(request):
            return 0
            return 0, False

        logger.info("External Cache Hit!")

@@ -257,9 +256,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
        num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)

        return num_tokens_to_check - num_computed_tokens
        return num_tokens_to_check - num_computed_tokens, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
@@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")

    # doc: end-chat-completion-extra-params

@@ -540,7 +543,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)
            logit_bias=self.logit_bias,
            extra_args=({"kv_transfer_params": self.kv_transfer_params}
                        if self.kv_transfer_params else None))

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
@@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel):
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."))

    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")

    # doc: end-completion-extra-params

    # Default sampling parameters for completion requests
@@ -973,7 +982,9 @@ class CompletionRequest(OpenAIBaseModel):
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids)
            allowed_token_ids=self.allowed_token_ids,
            extra_args=({"kv_transfer_params": self.kv_transfer_params}
                        if self.kv_transfer_params else None))

    @model_validator(mode="before")
    @classmethod
@@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel):
    model: str
    choices: list[CompletionResponseChoice]
    usage: UsageInfo
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters.")


class CompletionResponseStreamChoice(OpenAIBaseModel):
@@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel):
    choices: list[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters.")


class DeltaMessage(OpenAIBaseModel):

@@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing):
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        return response

@@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing):
            model=model_name,
            choices=choices,
            usage=usage,
        )
            kv_transfer_params=final_res_batch[0].kv_transfer_params)

    def _create_completion_logprobs(
        self,
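A hedged example (not part of this change) of how the new kv_transfer_params request field might be exercised against the OpenAI-compatible completions endpoint. The endpoint path and payload shape follow the standard vLLM server; the model name, host, and param values are placeholders:

# Illustrative client-side request; kv_transfer_params is the extra field
# added to CompletionRequest/ChatCompletionRequest in this diff.
import requests  # any HTTP client works; requests is assumed here

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    "prompt": "The capital of France is",
    "max_tokens": 1,
    "kv_transfer_params": {
        "do_remote_decode": True,   # ask this (prefill) instance to hand off
        "do_remote_prefill": False,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "remote_engine_id": None,
    },
}
resp = requests.post("http://prefill-host:8000/v1/completions", json=payload)
# The kv_transfer_params in the response (see CompletionResponse above) is
# then forwarded on the follow-up request to the decode instance.
print(resp.json().get("kv_transfer_params"))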
vllm/envs.py
@@ -112,6 +112,8 @@ if TYPE_CHECKING:
    VLLM_XGRAMMAR_CACHE_MB: int = 0
    VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
    VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
    VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557


def get_default_cache_root():
@@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # insecure method and it is needed for some reason.
    "VLLM_ALLOW_INSECURE_SERIALIZATION":
    lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),

    # IP address used for NIXL handshake between remote agents.
    "VLLM_NIXL_SIDE_CHANNEL_HOST":
    lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),

    # Port used for NIXL handshake between remote agents.
    "VLLM_NIXL_SIDE_CHANNEL_PORT":
    lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
}

# end-env-vars-definition
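A small sketch (not part of this change) of how these two environment variables combine with the tensor-parallel rank to form the handshake side-channel address, mirroring the logic in _nixl_handshake_listener and _nixl_handshake above; the rank value is just an example:

import os

# Defaults match the definitions above.
host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost")
base_port = int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557"))

# Each TP rank currently listens on its own port: base_port + rank.
tp_rank = 0  # example rank
path = f"tcp://{host}:{base_port + tp_rank}"  # e.g. tcp://localhost:5557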
@@ -11,10 +11,6 @@ import torch.distributed as dist

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger

if TYPE_CHECKING:
@@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)

    # KVConnector: trigger (possibly async) load before forward.
    # Each attn layer will block until the reading is complete.
    trigger_kv_transfer = (attn_metadata is not None
                           and has_kv_transfer_group()
                           and is_v1_kv_transfer_group())
    if trigger_kv_transfer:
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase_V1)
        kv_connector.start_load_kv(_forward_context)

    try:
        yield
    finally:
@@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any,
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)

        # KVConnector: each attn layer triggers (possibly async) save.
        # Ensure all those operations complete before forward() is done.
        if trigger_kv_transfer:
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase_V1)
            kv_connector.wait_for_save()

        _forward_context = prev_context
@@ -4,7 +4,7 @@ import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Generic, Optional, Union
from typing import Any, Generic, Optional, Union

import torch
from typing_extensions import TypeVar, deprecated
@@ -103,6 +103,7 @@ class RequestOutput:
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                  None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
        kv_transfer_params: The params for remote K/V transfer.
    """

    def __init__(
@@ -120,6 +121,7 @@ class RequestOutput:
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
@@ -133,11 +135,13 @@ class RequestOutput:
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids
        self.num_cached_tokens = num_cached_tokens
        self.kv_transfer_params = kv_transfer_params

    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
        """Merge subsequent RequestOutput into this one"""

        self.finished |= next_output.finished
        self.kv_transfer_params = next_output.kv_transfer_params

        for next_completion in next_output.outputs:
            for i, completion in enumerate(self.outputs):
@@ -36,6 +36,12 @@ class KVCacheBlocks:
        """Converts the KVCacheBlocks instance to a list of block IDs."""
        return [block.block_id for block in self.blocks]

    def get_unhashed_block_ids(self) -> list[int]:
        """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
        return [
            block.block_id for block in self.blocks if block.block_hash is None
        ]


class KVCacheManager:

@@ -116,6 +122,12 @@ class KVCacheManager:
            - The number of computed tokens.
        """

        # Request already has blocks from async load via KVConnector.
        num_existing_blocks = len(
            self.single_type_manager.req_to_blocks[request.request_id])
        if num_existing_blocks > 0:
            return KVCacheBlocks.create_empty(), request.num_computed_tokens

        # Prefix caching is disabled or
        # When the request requires prompt logprobs, we skip prefix caching.
        if (not self.enable_caching
@@ -173,6 +185,7 @@ class KVCacheManager:
        num_new_tokens: int,
        new_computed_blocks: Optional[KVCacheBlocks] = None,
        num_lookahead_tokens: int = 0,
        delay_cache_blocks: bool = False,
    ) -> Optional[KVCacheBlocks]:
        """Add slots for a request with new tokens to append.

@@ -186,6 +199,9 @@ class KVCacheManager:
            num_lookahead_tokens: The number of speculative tokens to allocate.
                This is used by spec decode proposers with kv-cache such
                as eagle.
            delay_cache_blocks: Whether to skip caching the blocks. This is
                used by P/D when allocating blocks used in a KV transfer
                which will complete in a future step.

        Blocks layout:
        ```
@@ -255,7 +271,9 @@ class KVCacheManager:
        new_blocks = self.single_type_manager.allocate_new_blocks(
            request.request_id, num_tokens_need_slot)

        if not self.enable_caching:
        # P/D: delay caching blocks if we have to recv from
        # remote. Update state for locally cached blocks.
        if not self.enable_caching or delay_cache_blocks:
            return KVCacheBlocks(new_blocks)

        # Speculated tokens might be rejected in the future, so we does
@@ -350,3 +368,16 @@ class KVCacheManager:
            A list of KV cache events.
        """
        return self.block_pool.take_events()

    def get_block_ids(self, request_id: str) -> list[int]:
        """Get the block ids of a request."""
        assert request_id in self.single_type_manager.req_to_blocks
        return [
            block.block_id
            for block in self.single_type_manager.req_to_blocks[request_id]
        ]

    def get_num_blocks(self, request_id: str):
        """Get the number of blocks."""
        assert request_id in self.single_type_manager.req_to_blocks
        return len(self.single_type_manager.req_to_blocks[request_id])
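For illustration (not part of this change), a hypothetical helper showing how the new allocate_slots(delay_cache_blocks=...) flag and get_unhashed_block_ids() compose when planning a remote-prefill pull; the function name and argument plumbing are made up, and the real logic lives in Scheduler.schedule() and the NIXL connector above:

def pull_plan_for_remote_prefill(kv_cache_manager, connector_meta, request,
                                 num_new_tokens, num_external_tokens,
                                 new_computed_blocks, params):
    """Hypothetical sketch of how the new pieces fit together."""
    new_blocks = kv_cache_manager.allocate_slots(
        request,
        num_new_tokens + num_external_tokens,
        new_computed_blocks,
        delay_cache_blocks=True,  # blocks are filled by an async KV transfer
    )
    if new_blocks is None:
        return None  # request cannot be scheduled yet
    # Only blocks without a prefix-cache hash must be read from the remote.
    connector_meta.add_new_req(
        request_id=request.request_id,
        local_block_ids=new_blocks.get_unhashed_block_ids(),
        kv_transfer_params=params,
    )
    return new_blocks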
@@ -4,6 +4,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.engine import EngineCoreOutputs
    from vllm.v1.metrics.stats import SchedulerStats
@@ -137,3 +138,6 @@ class SchedulerInterface(ABC):
    def shutdown(self) -> None:
        """Shutdown the scheduler."""
        raise NotImplementedError

    def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
        return None
@@ -5,13 +5,15 @@ from __future__ import annotations
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
KVTransferParams)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
@@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface):
|
||||
# This is flushed at the end of each scheduling step.
|
||||
self.finished_req_ids: set[str] = set()
|
||||
|
||||
# P/D: requests in process of recving KV transfers
|
||||
self.finished_recving_kv_req_ids: set[str] = set()
|
||||
|
||||
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
|
||||
# them at each scheduling step.
|
||||
# Request id -> deque of CachedRequestData
|
||||
@@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
request = self.waiting[0]
|
||||
|
||||
# P/D: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
continue
|
||||
|
||||
# Skip request if the structured output request is still waiting
|
||||
# for FSM compilation.
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
@@ -330,49 +345,55 @@ class Scheduler(SchedulerInterface):
continue

# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
new_computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)

# Get externally-cached tokens if using a KVConnector.
num_external_tokens = (
0 if self.connector is None else
num_external_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_computed_tokens))

# Total computed tokens (local + external).
num_computed_tokens += num_external_tokens

# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget

# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break

new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_tokens,
computed_blocks,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
if new_blocks is None:
# The request cannot be scheduled.
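The reshaped loop above hinges on the connector's scheduler-side contract: get_num_new_matched_tokens() now reports both how many tokens exist remotely and whether they must be pulled asynchronously, and allocate_slots() takes delay_cache_blocks so those blocks are not treated as locally cached yet. A simplified, standalone sketch of that decision (toy classes, not vLLM's API):

    from dataclasses import dataclass

    @dataclass
    class FakeRequest:
        num_tokens: int

    def plan_tokens(request: FakeRequest, num_local_tokens: int,
                    connector_result: tuple[int, bool],
                    token_budget: int) -> tuple[int, int, bool]:
        """Return (num_computed_tokens, num_new_tokens, park_for_remote_kv)."""
        num_external_tokens, load_kv_async = connector_result
        num_computed_tokens = num_local_tokens + num_external_tokens
        if load_kv_async:
            # Allocate blocks for the incoming KV, but schedule no new work yet.
            return num_computed_tokens, 0, True
        num_new_tokens = min(request.num_tokens - num_computed_tokens,
                             token_budget)
        return num_computed_tokens, num_new_tokens, False

    # e.g. a decode instance that will receive 512 prefilled tokens remotely:
    print(plan_tokens(FakeRequest(num_tokens=513), 0, (512, True), 1024))
    # -> (512, 0, True): park the request in WAITING_FOR_REMOTE_KVS
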
@@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_tokens,
)

self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue

if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
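With update_state_after_alloc() called on the connector, a request that must pull remote KV is parked rather than scheduled. A rough sketch of the resulting status transitions, using a toy enum that mirrors the RequestStatus values used above:

    import enum

    class Status(enum.Enum):  # toy mirror of the statuses used in this diff
        WAITING = enum.auto()
        WAITING_FOR_REMOTE_KVS = enum.auto()
        RUNNING = enum.auto()

    def next_status(status: Status, load_kv_async: bool, kv_ready: bool) -> Status:
        if status is Status.WAITING and load_kv_async:
            # Blocks are allocated, but the request is parked until the
            # remote KV arrives.
            return Status.WAITING_FOR_REMOTE_KVS
        if status is Status.WAITING_FOR_REMOTE_KVS and kv_ready:
            # Re-enters normal scheduling on a later pass.
            return Status.WAITING
        if status is Status.WAITING:
            return Status.RUNNING
        return status
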
@@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids()
self.kv_cache_manager.get_block_ids(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
@@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface):
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None

# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
@@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface):
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
kv_transfer_params = self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break

@@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface):

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
if new_token_ids or kv_transfer_params:

# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
@@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface):
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
))

else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
@@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface):
if not stopped:
new_running.append(request)

# P/D: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)

# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
@@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request)

def _free_request(self, request: Request) -> None:
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:

assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)

delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)

if not delay_free_blocks:
self._free_blocks(request)

return kv_xfer_params

def _free_blocks(self, request: Request):
assert request.is_finished()
assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]

def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)

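The split between _free_request() and _free_blocks() exists so a prefill instance can keep a finished request's blocks alive until its KV has been sent to the peer. A standalone sketch of that deferred-free pattern (the helper names here are illustrative, not vLLM's API):

    from typing import Any, Callable, Optional

    def free_request(
        request_id: str,
        connector_finished: Callable[[str], tuple[bool, Optional[dict[str, Any]]]],
        free_blocks: Callable[[str], None],
    ) -> Optional[dict[str, Any]]:
        # Mirrors _free_request()/_connector_finished() above: the connector says
        # whether block freeing must wait for an in-flight KV send, and returns
        # the params that ride back to the client on the EngineCoreOutput.
        delay_free_blocks, kv_xfer_params = connector_finished(request_id)
        if not delay_free_blocks:
            free_blocks(request_id)
        # Otherwise the blocks are freed later, when the worker reports the
        # request id in finished_sending.
        return kv_xfer_params
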
@@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface):
def shutdown(self) -> None:
if self.kv_event_publisher:
self.kv_event_publisher.shutdown()

########################################################################
# P/D Related Methods
########################################################################

def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
return self.connector

def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
"""Invoke the KV connector request_finished() method if applicable."""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)

def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
P/D: check if the request_id is finished_recving.

The finished_recving_kv_req_ids list is populated
on the previous step's update_from_output based
on the worker side connector.

When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
"""
if request.request_id not in self.finished_recving_kv_req_ids:
return False

# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.single_type_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,
)

# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens

# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
return True

def _update_from_kv_xfer_finished(self,
model_runner_output: ModelRunnerOutput):
"""
P/D: update the scheduler state based on the output.

The Worker side connectors add finished_recving and
finished_sending reqs to the output.
* if finished_sending: free the blocks
* if finished_recving: add to state so we can
schedule the request during the next step.
"""
# P/D: update recv and send status from last step.
for req_id in (model_runner_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
for req_id in (model_runner_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])

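These helpers close the loop with the worker: finished_sending lets the producer reclaim blocks it kept alive, while finished_recving marks the consumer's parked request schedulable on the next pass. A simplified standalone restatement of that bookkeeping (names illustrative):

    from typing import Callable, Optional

    def apply_kv_xfer_results(
        finished_sending: Optional[set[str]],
        finished_recving: Optional[set[str]],
        finished_recving_kv_req_ids: set[str],
        free_blocks: Callable[[str], None],
    ) -> None:
        for req_id in (finished_recving or ()):
            # Consumer side: picked up by _update_waiting_for_remote_kv() on
            # the next scheduling pass, moving the request back to WAITING.
            finished_recving_kv_req_ids.add(req_id)
        for req_id in (finished_sending or ()):
            # Producer side: the peer has pulled the prefilled KV, so the
            # blocks held by the deferred free can now be reclaimed.
            free_blocks(req_id)
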
@@ -105,6 +105,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None

@property
def finished(self) -> bool:

@@ -182,6 +182,15 @@ class EngineCore:
# Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req)

if req.raw_kv_transfer_params is not None:
if (kv_connector := self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector.
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")

self.scheduler.add_request(req)

def abort_requests(self, request_ids: list[str]):

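On the client side, the raw dictionary that add_request() hands to the connector is expected to arrive via SamplingParams.extra_args under the "kv_transfer_params" key (see the Request changes further below). A hedged usage sketch; the inner payload keys are placeholders, since each connector defines its own schema:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # any model; for illustration only
    params = SamplingParams(
        max_tokens=32,
        # Assumed: extra_args is the field read by Request below; the inner
        # dict is connector-specific and shown here with placeholder content.
        extra_args={"kv_transfer_params": {"do_remote_prefill": True}},
    )
    outputs = llm.generate(["Hello"], params)
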
@@ -3,7 +3,7 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union
from typing import Any, Optional, Union

from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
@@ -146,6 +146,7 @@ class RequestState:
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Optional[RequestOutput]:

finished = finish_reason is not None
@@ -167,13 +168,15 @@ class RequestState:
if not outputs:
return None

return self._new_request_output(request_id, outputs, finished)
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)

def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> RequestOutput:

if self.output_kind == RequestOutputKind.DELTA:
@@ -189,6 +192,7 @@ class RequestState:
prompt_logprobs=prompt_logprobs,
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
)

def _new_completion_output(
@@ -335,6 +339,7 @@ class OutputProcessor:
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params

req_state.is_prefilling = False

@@ -350,7 +355,8 @@ class OutputProcessor:

# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason):
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)

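At the other end of this plumbing, the parameters emitted by the connector when a request finishes surface on the RequestOutput, so an orchestrator can hand them to the peer instance. A short sketch of consuming them, continuing the hypothetical `outputs` from the earlier client example:

    for output in outputs:
        if output.finished and output.kv_transfer_params is not None:
            # e.g. forward this dict to the decode instance so it can pull
            # the prefilled KV blocks from this engine.
            print(output.request_id, output.kv_transfer_params)
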
@@ -100,12 +100,16 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None

EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)

EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
finished_sending=None,
finished_recving=None)

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
@@ -61,6 +62,15 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0

# P/D: KV transfer parameters (raw and parsed).
raw_params = (None if sampling_params.extra_args is None
else sampling_params.extra_args.get(
"kv_transfer_params", None))
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None

# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes:
@@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import gc
import time
import weakref
@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
@@ -1065,15 +1067,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
scheduler_output.kv_connector_metadata)

self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT

return self.kv_connector_no_forward(scheduler_output)

# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model(
self.maybe_setup_kv_connector(scheduler_output)

model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)

self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))

if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = output
hidden_states, aux_hidden_states = model_output
else:
hidden_states = output
hidden_states = model_output

if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
)

def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))

if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT

output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output

@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)

# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())

@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()

@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None

def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
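Taken together, the helpers above define the worker's per-step sequence around the forward pass. A simplified, standalone sketch of that ordering; the connector methods named here are the ones called above, while the wrapper function itself is illustrative:

    def run_one_step(connector, forward_context, scheduler_output, run_model):
        # 1. Bind this step's connector metadata and kick off async KV loads.
        connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
        connector.start_load_kv(forward_context)
        # 2. Run the model forward pass while transfers proceed in the background.
        model_output = run_model()
        # 3. Block until any KV saves required by this step have completed.
        connector.wait_for_save()
        # 4. Report transfer completions back to the scheduler.
        finished_sending, finished_recving = connector.get_finished(
            scheduler_output.finished_req_ids)
        return model_output, finished_sending, finished_recving
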
@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)

if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)

self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec,

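For completeness, a hedged configuration sketch: wiring a vLLM instance to the newly registered connector through KVTransferConfig. Whether your entry point accepts the config object directly or a JSON string may differ by version:

    from vllm import LLM
    from vllm.config import KVTransferConfig

    # kv_role follows the choices documented on KVTransferConfig
    # ('kv_producer', 'kv_consumer', or 'kv_both').
    kv_cfg = KVTransferConfig(kv_connector="NixlConnector", kv_role="kv_both")
    llm = LLM(model="facebook/opt-125m", kv_transfer_config=kv_cfg)  # assumed kwarg
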