[Core] Implement disagg prefill by StatelessProcessGroup (#10502)

This PR provides initial support for single-node disaggregated prefill in 1P1D scenario.
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
This commit is contained in:
Kuntai Du
2024-12-01 19:01:00 -06:00
committed by GitHub
parent c11f172187
commit 0590ec3fd9
33 changed files with 2525 additions and 21 deletions

View File

@@ -0,0 +1,30 @@
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main usecase is for disaggregated prefilling.
## Abstractions
The KV cache transfer contains three layer of abstractions:
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
RDMA database).
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
## Disaggregated prefilling
The example usage is in [this file](../../../examples/disaggregated_prefill.sh).
Here is the diagram of how we run disaggretgated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 139 KiB

View File

@@ -0,0 +1,122 @@
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class KVConnectorBase(ABC):
"""
Abstract base class for a KV connector.
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
connector when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
"""
Send KV caches and hidden states to the connector.
This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.
Args:
model_executable (torch.nn.Module): The model executable containing
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.
Returns:
None
"""
raise NotImplementedError
@abstractmethod
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.
This method attempts to retrieve KV caches and hidden states for input
tokens. If all required KV caches and hidden states are received, it
will bypass model input, else it will fall back to normal vLLM model
forwarding.
Args:
model_executable (torch.nn.Module):
The model executable from vLLM modelrunner.
model_input (ModelInputForGPUWithSamplingMetadata):
The model input from vLLM modelrunner.
kv_caches (List[torch.Tensor]):
List of KV caches for each layer.
Returns:
- hidden_or_intermediate_states (torch.Tensor or
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
Indicates whether the model execution can be skipped (True) or
needs to be redone (False).
- model_input (ModelInputForGPUWithSamplingMetadata):
Optionally adjusted input metadata for re-execution when
`bypass_model_exec=False`.
"""
raise NotImplementedError

View File

@@ -0,0 +1,19 @@
from typing import TYPE_CHECKING
from .base import KVConnectorBase
if TYPE_CHECKING:
from vllm.config import VllmConfig
class KVConnectorFactory:
@staticmethod
def create_connector(rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if config.kv_transfer_config.kv_connector == 'PyNcclConnector':
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
raise ValueError(f"Unsupported connector type: "
f"{config.kv_connector}")

View File

@@ -0,0 +1,261 @@
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config.kv_transfer_config
logger.info("Initializing PyNcclConfig under kv_transfer_config %s",
self.config)
self.lookup_buffer_size = self.config.kv_buffer_size
self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None
# 2 pipes for every rank in the world
port_offset_base = 2 * rank
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)
else:
# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producder
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
self.config.kv_buffer_size,
)
def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
# FIXME(Kuntai): This assume that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
_, _, num_heads, head_size = kv_cache[0].shape
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self.select(current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
roi: torch.Tensor = ret[1]
keys: torch.Tensor = ret[2]
values: torch.Tensor = ret[3]
hidden: torch.Tensor = ret[4]
num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.producer_data_pipe.close()
self.producer_signal_pipe.close()
self.consumer_data_pipe.close()
self.consumer_signal_pipe.close()

View File

@@ -0,0 +1,108 @@
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
All distributed communications are abstracted behind this class.
"""
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
class KVLookupBufferBase(ABC):
"""
Abstract base class for a lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` is `None`, it means selecting any of the
KV caches in the buffer, return, and remove it from the buffer, useful
when offloading KV cache to KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
List[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
lookup buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@@ -0,0 +1,242 @@
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
import time
from collections import deque
from typing import Deque, List, Optional, Union
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVLookupBufferBase)
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""
self.buffer: Deque[List[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(self, tokens_roi_sender: List[torch.Tensor],
tokens_roi_recver: List[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length],
tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self,
tensor: Optional[torch.Tensor]) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
with self.buffer_lock:
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer.append(buffer_item)
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
matched_length = 0
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with self.buffer_lock:
for _ in range(len(self.buffer)):
temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None)
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler)
self.request_handling_thread.start()
def close(self):
if hasattr(self, "request_handling_thread"
) and self.request_handling_thread is not None:
self.request_handling_thread.join()
else:
# TODO: have a explicit close signal and have a explicit way to
# check if it's requester
self.signal_pipe.send_tensor(self.end_signal)

View File

@@ -0,0 +1,65 @@
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class KVPipeBase(ABC):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@abstractmethod
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@@ -0,0 +1,276 @@
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple
import torch
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
logger = init_logger(__name__)
class BrokenPipeException(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
Metadata = Dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase):
METADATA_LENGTH = 16
MAX_TENSOR_DIMENSIONS = 14
METADATA_DTYPE = torch.int64
def __init__(self,
local_rank: int,
config: KVTransferConfig,
device: Optional[str] = None,
port_offset: int = 0):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
self.kv_parallel_size = self.config.kv_parallel_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
# set target rank
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: Optional[ThreadPoolExecutor] = None
self.buffer_size = 0
self.buffer_size_lock = threading.Lock()
self.buffer_size_thresh = self.config.kv_buffer_size
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
[torch.Tensor, int], None]]:
send: Callable[[torch.Tensor, int], None]
recv: Callable[[torch.Tensor, int], None]
if self.device.type == "cuda":
# use PyNCCL for send / recv
comm = PyNcclCommunicator(group, device=self.local_rank)
comm.disabled = False
send, recv = comm.send, comm.recv # type: ignore
else:
# This send / recv implementation here is NOT intended to transfer
# KV caches (and should NOT be repurposed to transfer KV caches).
# Currently it is only used to transmit control-plane messages
# for PyNcclBuffer.
send = group.send_obj
def my_recv(x, src):
x[...] = group.recv_obj(src)
recv = my_recv
return send, recv
def _select_device(self, device: str):
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
"""
Create the metadata as a dictionary based on the input tensor.
Parameters:
- tensor: The input tensor or None if no tensor is provided.
Returns:
- metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
if tensor is None:
return {"dtype": None, "shape": None}
else:
return {"dtype": tensor.dtype, "shape": tensor.shape}
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
"""
Create a buffer to receive the tensor based on the provided metadata.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.
Returns:
- buffer: A tensor of the specified type and shape, allocated on
self.device.
"""
return torch.empty(metadata["shape"],
dtype=metadata["dtype"],
device=self.device)
def _send_metadata(self, metadata: Metadata):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
def _recv_metadata(self) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
if tensor is not None:
self.device_send_func(tensor.to(self.device),
self.target_rank_for_send)
def _recv_impl(self) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
Returns:
- buffer: The received tensor, or None if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
self.device_recv_func(buffer, self.target_rank_for_recv)
return buffer
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
tensor_size: int) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
self._send_impl(tensor)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
except Exception as e:
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
torch.distributed.get_rank(), str(tensor), str(e))
import traceback
traceback.print_exc()
def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
if tensor is not None:
tensor_size = tensor.element_size() * tensor.numel()
else:
tensor_size = 0
self.block_if_full()
with self.buffer_size_lock:
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
tensor_size)
def recv_tensor(self) -> Optional[torch.Tensor]:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
future = self.transport_thread.submit(self._recv_impl)
try:
tensor = future.result()
except Exception as e:
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)
logger.error("My device: %s", self.device)
import traceback
traceback.print_exc()
raise e
return tensor
def close(self):
"""
Close the pipe and release associated resources.
"""
if hasattr(self,
"transport_thread") and self.transport_thread is not None:
self.transport_thread.shutdown()

View File

@@ -0,0 +1,75 @@
"""A centralized entrypoint to perform distributed KV cache transfer.
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states
"""
from typing import TYPE_CHECKING, List, Tuple, Union
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.config import VllmConfig
import torch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
class KVTransferAgent:
"""
A class designated for distributed KV transfer
Target use cases:
1. Disaggregated prefill
2. Remote KV cache storage
"""
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
self.config = config
if config.kv_transfer_config is None:
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
" cannot initialize KVConnector.")
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector(
rank, local_rank, config)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
self.connector.send_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches,
hidden_or_intermediate_states)
def close(self) -> None:
self.connector.close()
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
return self.connector.recv_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches)