Enable cross-layers KV cache layout in NIXL Connector V2 (#33339)

Signed-off-by: Liran Schour <lirans@il.ibm.com>
Signed-off-by: liranschour <liranschour@users.noreply.github.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Commit 8322d4e47f (parent 3e472e81f9)
Author: liranschour
Date: 2026-02-05 12:17:02 +02:00
Committed by: GitHub
6 changed files with 339 additions and 89 deletions


@@ -14,6 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vll
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -192,8 +193,6 @@ def copy_kv_blocks(
dst_device=dst_device,
)
from vllm.platforms import current_platform
if direction == "h2d":
copy_fn = current_platform.insert_blocks_to_device
else:
@@ -316,12 +315,14 @@ class TpKVTopology:
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
_MOCK_BLOCK_SIZE = 16
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
)
# Non-MLA backends' caches have 5 dims [2, num_blocks, H, N, D];
# we just mock num_blocks to 1 for the dimension check below.
@@ -329,6 +330,36 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
self._cross_layers_blocks = False
if self.tensor_shape is not None:
self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1
)
if self._cross_layers_blocks:
# prepend layers dimension
_MOCK_NUM_LAYERS = 80
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
# In the cross-layers case, permute kv_cache_shape according to the
# stride order to recover the physical position of block_size.
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
# In the default non-cross-layers layout the block_size position is
# logical, while in the cross-layers case it is the physical position.
# This matches the shape of the actual KV cache tensors passed to
# register_kv_caches()/register_cross_layers_kv_cache().
block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
assert block_size_position >= 0
self._block_size_position = -(len(kv_cache_shape) - block_size_position)
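
The probe above can be read in isolation: every mocked dimension is 1 except block_size, so after permuting by the backend's stride order the block_size axis is the unique axis of size 16, and storing it as a negative index keeps it valid for real tensors regardless of how many leading dimensions they carry. A minimal standalone sketch, with the helper name and the NHD shape assumed for illustration:

# Standalone sketch of the block_size-position probe in __post_init__.
# All mocked sizes except block_size are 1, so block_size is unique.
_MOCK_BLOCK_SIZE = 16
_MOCK_NUM_LAYERS = 80

def find_block_size_position(
    kv_cache_shape: tuple[int, ...],
    stride_order: tuple[int, ...] | None,
    cross_layers: bool,
) -> int:
    if cross_layers:
        # Prepend the layers dimension, mirroring the cross-layers tensor.
        kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
    if stride_order is None:
        # Fall back to the logical order, as the except branch above does.
        stride_order = tuple(range(len(kv_cache_shape)))
    # Permute to physical order so .index() finds the physical axis.
    physical = tuple(kv_cache_shape[i] for i in stride_order)
    pos = physical.index(_MOCK_BLOCK_SIZE)
    return pos - len(physical)  # negative index, e.g. -3

# Mocked non-MLA shape [2, num_blocks, block_size, H, D] (NHD, assumed):
assert find_block_size_position((2, 1, _MOCK_BLOCK_SIZE, 1, 1), None, False) == -3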
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@@ -336,7 +367,9 @@ class TpKVTopology:
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (self.is_mla or self.is_kv_layout_blocks_first)
return not (
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
@@ -346,6 +379,14 @@ class TpKVTopology:
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
@property
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
@property
def block_size_position(self) -> int:
return self._block_size_position
def tp_ratio(
self,
remote_tp_size: int,


@@ -54,7 +54,7 @@ from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
@@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str
vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
@@ -216,6 +216,7 @@ def compute_nixl_compatibility_hash(
# Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
"cross_layers_blocks": cross_layers_blocks,
}
compat_hash = hash_factors(factors)
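
Folding the new flag into the hash means a cross-layers producer and a non-cross-layers consumer can never pass the handshake compatibility check. A hedged sketch of the effect (the real hash_factors implementation is not part of this diff; a stable JSON digest is assumed here):

import hashlib
import json

def hash_factors(factors: dict) -> str:
    # Assumption: a stable digest over the sorted factor dict.
    return hashlib.sha256(
        json.dumps(factors, sort_keys=True).encode()
    ).hexdigest()

base = {"attn_backend_name": "FLASH_ATTN", "cache_dtype": "auto"}
a = hash_factors(base | {"cross_layers_blocks": True})
b = hash_factors(base | {"cross_layers_blocks": False})
# Mismatched layouts now yield different hashes, so the handshake
# check (enforce_handshake_compat) rejects the incompatible peer.
assert a != b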
@@ -298,6 +299,26 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1):
@property
def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config)
if backend.get_name() not in (
"FLASH_ATTN",
"FLASHINFER",
):
return False
# For now there is no benefit to running cross-layers blocks when
# the backend does not use the HND KV cache layout
if get_kv_cache_layout() != "HND":
return False
extra_config = self.kv_transfer_config.kv_connector_extra_config
return (
str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
== "true"
)
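
The layout is therefore opt-in through the connector's extra config, and the str(...).lower() comparison accepts both booleans and strings. Assuming the usual way a KVTransferConfig is constructed, enabling it would look roughly like:

# Hedged example: opting in to the cross-layers layout. The extra-config
# key is taken from this diff; the surrounding config shape is assumed.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        # Compared case-insensitively against "true" in
        # prefer_cross_layer_blocks; True (bool) would also work.
        "enable_cross_layers_blocks": "true",
    },
)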
def __init__(
self,
vllm_config: VllmConfig,
@@ -309,7 +330,7 @@ class NixlConnector(KVConnectorBase_V1):
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: NixlConnectorScheduler | None = (
NixlConnectorScheduler(vllm_config, self.engine_id)
@@ -395,6 +416,16 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
self.connector_worker.register_kv_caches(kv_caches)
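
Instead of one entry per attention layer, the whole cross-layers cache arrives as a single tensor and is registered under one synthetic key, reusing the existing register_kv_caches path. Illustrative shapes only (HND dims and sizes assumed):

import torch

num_layers, num_blocks, H, B, D = 80, 4, 8, 16, 128

# Per-layer path: one 5-D tensor per attention layer.
per_layer = {
    f"layers.{i}.attn": torch.empty(2, num_blocks, H, B, D)
    for i in range(num_layers)
}
# Cross-layers path: one 6-D tensor with a leading layers dimension,
# registered under the synthetic "ALL_LAYERS" key.
cross_layers = {"ALL_LAYERS": torch.empty(num_layers, 2, num_blocks, H, B, D)}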
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
@@ -976,20 +1007,17 @@ class NixlConnectorWorker:
# Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet
backend = get_current_attn_backend(vllm_config)
self.attn_backend = get_current_attn_backend(vllm_config)
self.backend_name = backend.get_name()
self.backend_name = self.attn_backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name
)
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
# lazily initialized in register_kv_caches
self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
@@ -998,17 +1026,12 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._physical_blocks_per_logical_kv_block = 1
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
def _nixl_handshake(
self,
host: str,
@@ -1022,6 +1045,7 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that the
# current local rank will read from. Note that with homogeneous TP,
# this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
@@ -1059,6 +1083,7 @@ class NixlConnectorWorker:
)
# Check compatibility hash BEFORE decoding agent metadata
assert self.compat_hash is not None
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
@@ -1267,6 +1292,20 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=next(iter(kv_caches.values())).shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
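
Constructing TpKVTopology here, rather than in __init__, is what allows the layout to be inferred from the first registered tensor: a cross-layers cache has exactly one extra leading dimension over the backend's per-layer shape. Restated with illustrative numbers:

# Rank-based detection, as in TpKVTopology.__post_init__ (shapes assumed).
per_layer_rank = 5                          # [2, num_blocks, H, N, D]
tensor_shape = (80, 2, 1024, 8, 16, 128)    # leading num_layers dimension
cross_layers_blocks = len(tensor_shape) == per_layer_rank + 1
assert cross_layers_blocks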
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
@@ -1301,29 +1340,21 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (i.e. num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
if self.device_type == "cpu":
block_size_position = -2
else:
block_size_position = -2 if self.use_mla else -3
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
cache_list = (
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
)
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
continue
kernel_block_size = cache.shape[block_size_position]
kernel_block_size = cache.shape[self.kv_topo.block_size_position]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
@@ -1385,6 +1416,7 @@ class NixlConnectorWorker:
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
@@ -1440,6 +1472,7 @@ class NixlConnectorWorker:
block_size=self.block_size,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash,
@@ -1461,6 +1494,8 @@ class NixlConnectorWorker:
register another local_xfer_handler using remote block len to ensure
data copy correctness.
"""
assert self.kv_topo is not None
block_size_ratio = self.block_size // block_size
blocks_data = []
for i, base_addr in enumerate(self.seen_base_addresses):
@@ -1573,6 +1608,7 @@ class NixlConnectorWorker:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks:
@@ -1701,6 +1737,7 @@ class NixlConnectorWorker:
remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size
assert self.kv_topo is not None
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
@@ -1837,6 +1874,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
@@ -1856,7 +1894,7 @@ class NixlConnectorWorker:
block_size_ratio,
)
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
split_k_and_v = self.kv_topo.split_k_and_v
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
@@ -1881,6 +1919,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
assert self.kv_topo is not None
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -1950,6 +1989,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert self.kv_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
@@ -2109,7 +2149,7 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
assert meta.remote is not None and self.kv_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
@@ -2182,6 +2222,7 @@ class NixlConnectorWorker:
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
@@ -2414,6 +2455,7 @@ class NixlConnectorWorker:
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
block_len = self.block_len_per_layer[layer_idx] // 2
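
Because a blocks-first block holds K and V contiguously, addressing only one of them uses half the registered block length, and since the K and V halves are identical in size the division is exact. A worked example with assumed sizes:

import torch

H, B, D = 8, 16, 128  # heads, block_size, head_size (assumed)
elem = torch.tensor([], dtype=torch.bfloat16).element_size()  # 2 bytes

full_block_len = 2 * H * B * D * elem   # K and V in one region
k_or_v_len = full_block_len // 2        # indexing just the K or V half
assert full_block_len % 2 == 0
print(full_block_len, k_or_v_len)       # 65536 32768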