[PD][Nixl] Add support for hybrid SSM-FA models (#36687)

Nicolò Lucchesi
2026-03-16 19:58:06 +01:00
committed by GitHub
parent c88ea8338b
commit f5c081d432
7 changed files with 584 additions and 163 deletions


@@ -16,10 +16,12 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.v1.kv_cache_interface import KVCacheSpec
logger = init_logger(__name__)
@@ -328,22 +330,26 @@ class TpKVTopology:
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
attn_backends: list[type[AttentionBackend]]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
is_mamba: bool = False
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
_MOCK_BLOCK_SIZE = 16
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
)
logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
attn_backend = self.attn_backends[0]
if not self.is_mamba:
_MOCK_BLOCK_SIZE = 16
kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
)
logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
# Non-MLA backend caches have 5 dims [2, num_blocks, H, N, D];
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
# Hybrid SSM models assume a single blocks_first layout
self._is_kv_layout_blocks_first = self.is_mamba or (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
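The probe above only needs the leading dimension of the mocked shape; a minimal standalone sketch of the same test, where the 5-dim shapes are assumptions for a FlashAttn-like and a FlashInfer-like backend:

# Minimal sketch of the blocks-first probe (shapes below are assumptions,
# not taken from a specific backend).
def is_blocks_first(kv_cache_shape: tuple[int, ...], is_mamba: bool) -> bool:
    # Hybrid SSM models always assume a blocks-first layout.
    if is_mamba:
        return True
    # num_blocks was mocked to 1, so a leading 1 means num_blocks comes first.
    return len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1

assert not is_blocks_first((2, 1, 16, 1, 1), is_mamba=False)  # K/V-first, FlashAttn-like
assert is_blocks_first((1, 2, 16, 1, 1), is_mamba=False)      # blocks-first, FlashInfer-like
assert is_blocks_first((), is_mamba=True)                     # hybrid SSM: forced blocks-first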
@@ -360,7 +366,7 @@ class TpKVTopology:
_MOCK_NUM_LAYERS = 80
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
@@ -483,6 +489,30 @@ class TpKVTopology:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size)
def get_transfer_cache_regions(
self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
) -> list[torch.Tensor] | torch.Tensor:
"""Return the cache tensor(s) to register as NIXL memory regions,
also accounting for the specifics of hybrid SSM models.
"""
if isinstance(layer_spec, MambaSpec):
# Register the whole shared kv cache tensor, including SSM/Conv. This is
# similar to FI, with the difference that SSM and Conv have different sizes.
conv, ssm = cache
return [conv]
# This check may be hacky, but it matches `_update_hybrid_attention_mamba_layout`.
if self.is_mamba and cache.shape[0] == 2:
# When MAMBA is present, all backends are blocks-first, so that blocks
# can be shared between attention layers and mamba layers. The runner's
# `_update_hybrid_attention_mamba_layout` has already adjusted strides
# for FlashAttn-like backends so that num_blocks comes first.
# Swap the [2 <-> num_blocks] dims to get the layout required for hybrid SSM.
cache = cache.transpose(0, 1)
# Regular case: backends like FA register K/V in separate regions
return cache if self.split_k_and_v else [cache]
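For the hybrid case, the [2 <-> num_blocks] swap above is a stride-only view; a small sketch with assumed toy sizes:

import torch

# Sketch of the [2 <-> num_blocks] swap on an assumed FlashAttn-like cache
# of shape (2, num_blocks, block_size, num_kv_heads, head_size).
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
blocks_first = cache.transpose(0, 1)
assert blocks_first.shape[:2] == (num_blocks, 2)
# Only the strides change; the storage registered with NIXL stays the same.
assert blocks_first.data_ptr() == cache.data_ptr()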
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None


@@ -564,7 +564,7 @@ class MooncakeConnectorWorker:
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
attn_backends=[backend],
)
self.async_zmq_ctx = zmq.asyncio.Context()


@@ -59,7 +59,12 @@ from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size
@@ -159,6 +164,7 @@ class NixlAgentMetadata:
block_lens: list[int]
kv_cache_layout: str
block_size: int
ssm_sizes: tuple[int, int]
@dataclass
@@ -310,6 +316,15 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
@property
def prefer_cross_layer_blocks(self) -> bool:
if any(
isinstance(group.kv_cache_spec, MambaSpec)
for group in self.kv_cache_config.kv_cache_groups
):
# Hybrid SSM models do not yet support cross-layer layout
return False
backend = get_current_attn_backend(self._vllm_config)
if backend.get_name() not in (
"FLASH_ATTN",
@@ -335,12 +350,9 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
for group in kv_cache_config.kv_cache_groups:
if isinstance(group.kv_cache_spec, MambaSpec):
raise ValueError("NixlConnector does not support Mamba models.")
self.kv_cache_config = kv_cache_config
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
if role == KVConnectorRole.SCHEDULER:
@@ -434,11 +446,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
self.connector_worker.register_kv_caches(kv_caches)
self.connector_worker.register_cross_layers_kv_caches(kv_cache)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
@@ -962,6 +970,40 @@ class NixlConnectorWorker:
)
)
self.kv_cache_config = kv_cache_config
self._layer_specs = {
layer: group.kv_cache_spec
for group in kv_cache_config.kv_cache_groups
for layer in group.layer_names
}
self.hma_group_size = len(kv_cache_config.kv_cache_tensors)
# Mamba metadata
self._is_mamba_group = [
isinstance(group.kv_cache_spec, MambaSpec)
for group in kv_cache_config.kv_cache_groups
]
mamba_ssm_size = (0, 0)
self._has_mamba = any(self._is_mamba_group)
if self._has_mamba:
assert self._is_hma_required
mamba_spec = next(
spec
for spec in self._layer_specs.values()
if isinstance(spec, MambaSpec)
)
conv_nbytes, ssm_nbytes = (
torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(), # type: ignore[misc]
torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(), # type: ignore[misc]
)
conv_shape, ssm_shape = (
torch.Size(mamba_spec.shapes[0]),
torch.Size(mamba_spec.shapes[1]),
)
mamba_ssm_size = (
conv_shape.numel() * conv_nbytes,
ssm_shape.numel() * ssm_nbytes,
)
self._mamba_ssm_size = mamba_ssm_size
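The (conv, ssm) byte sizes are derived from the spec's per-block shapes and dtypes; a worked sketch with assumed values (the real ones come from MambaSpec.shapes/dtypes):

import torch

# Worked sketch of the (conv, ssm) byte-size computation; shapes and dtypes
# below are assumptions.
conv_shape, ssm_shape = torch.Size((128, 4)), torch.Size((8, 64, 128))
conv_dtype, ssm_dtype = torch.bfloat16, torch.float32
mamba_ssm_size = (
    conv_shape.numel() * torch.tensor([], dtype=conv_dtype).element_size(),
    ssm_shape.numel() * torch.tensor([], dtype=ssm_dtype).element_size(),
)
assert mamba_ssm_size == (1024, 262144)  # 512 * 2 bytes, 65536 * 4 bytes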
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
@@ -1106,9 +1148,9 @@ class NixlConnectorWorker:
# Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet
self.attn_backend = get_current_attn_backend(vllm_config)
self.attn_backends = get_current_attn_backends(vllm_config)
self.backend_name = self.attn_backends[0].get_name()
self.backend_name = self.attn_backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.info("Detected attention backend %s", self.backend_name)
@@ -1135,6 +1177,8 @@ class NixlConnectorWorker:
def _sync_block_size_with_kernel(self) -> None:
backends = get_current_attn_backends(self.vllm_config)
kernel_block_size = select_common_block_size(self.block_size, backends)
# Number of blocks not accounting for kernel block mismatches
self._logical_num_blocks = self.num_blocks
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
@@ -1428,9 +1472,19 @@ class NixlConnectorWorker:
fut.add_done_callback(request_ready)
def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
"""Register a cross-layers KV cache tensor with NIXL.
`use_uniform_kv_cache()` guarantees a single KV cache group whose
layers all share the same `AttentionSpec`, so any layer name from
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
"""
first_layer = next(iter(self._layer_specs))
# Forwarding a real layer name rather than a synthetic key
self.register_kv_caches({first_layer: kv_cache})
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
@@ -1438,8 +1492,12 @@ class NixlConnectorWorker:
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=next(iter(kv_caches.values())).shape,
attn_backends=self.attn_backends,
# SSM states come in (conv, ssm) tuples
tensor_shape=next(iter(kv_caches.values())).shape
if not self._has_mamba
else None,
is_mamba=self._has_mamba,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
@@ -1481,12 +1539,50 @@ class NixlConnectorWorker:
# to better exploit the memory layout (ie num_blocks is the first dim).
tensor_size_bytes = None
# Enable different block lengths for different layers when MLA is used.
# Enable different block lengths for different layers *only* when MLA is used.
# This is not used for SSM layers, which use the counterpart `mamba_ssm_size`.
self.block_len_per_layer = list[int]()
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = (
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
# NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to
# that of FI, with blocks laid out as in `get_backend_aware_kv_block_len`.
# However, the physical page_size may differ when the kernel requires a specific
# block size. This leads to SSM and FA layers having different num_blocks.
# The `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this.
layer_spec = self._layer_specs[layer_name]
if isinstance(layer_spec, UniformTypeKVCacheSpecs):
# MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
layer_spec = layer_spec.kv_cache_specs[layer_name]
cache_list = self.kv_topo.get_transfer_cache_regions(
cache_or_caches, layer_spec
)
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
# the page_size assuming constant `self._logical_num_blocks`.
physical_page_size = (
layer_spec.page_size_bytes
if isinstance(layer_spec, MambaSpec)
else layer_spec.page_size_bytes
// self._physical_blocks_per_logical_kv_block
)
# For when registering multiple tensors eg K/V in separate regions.
physical_page_size = physical_page_size // len(cache_list)
if self.kv_topo._cross_layers_blocks:
# When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors
)
num_blocks = (
self._logical_num_blocks
if isinstance(layer_spec, MambaSpec)
else self.num_blocks
)
# `page_size` accounts for physical blocks, such that the KVCache is always
# [`num_blocks` * `page_size`]
curr_tensor_size_bytes = num_blocks * physical_page_size
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
# TODO (NickLucche) we could eventually unify how we handle FA/FI regions,
# registering a single tensor for both K/V and splitting logically like FI.
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
@@ -1494,27 +1590,27 @@ class NixlConnectorWorker:
# across groups. This results in skipping all tensors but the ones
# pointed to by group0. Also, generally we will have more blocks
# per tensor but fewer regions.
logger.debug("Skipping %s because it's already seen", layer_name)
continue
logger.debug(
"Registering layer %s with cache shape: %s", layer_name, cache.shape
)
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
# Record per-layer block lengths at kernel-block granularity; Mamba pages
# are scaled down by the physical/logical block ratio.
if isinstance(layer_spec, MambaSpec):
self.block_len_per_layer.append(
physical_page_size // self._physical_blocks_per_logical_kv_block
)
else:
self.block_len_per_layer.append(physical_page_size)
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
assert cache.shape[0] == self.num_blocks, (
assert cache.shape[0] == num_blocks, (
"All kv cache tensors must have the same number of blocks"
)
self.block_len_per_layer.append(
curr_tensor_size_bytes // self.num_blocks
)
if not self.use_mla:
# Different kv cache shape is not supported by HeteroTP
# Different kv cache shape is not supported by HeteroTP.
# This must also hold true for Mamba-like models.
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
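The logical/physical page-size bookkeeping above boils down to keeping `num_blocks * page_size` constant under either view; a worked example with assumed numbers:

# Worked example of the logical vs. physical page-size adjustment
# (all numbers below are assumptions for illustration).
logical_page_size_bytes = 64 * 1024   # per logical block, from layer_spec.page_size_bytes
physical_blocks_per_logical = 4       # e.g. a 128-token logical block over 32-token kernel blocks
logical_num_blocks = 100
physical_num_blocks = logical_num_blocks * physical_blocks_per_logical

physical_page_size = logical_page_size_bytes // physical_blocks_per_logical
# The registered tensor size is identical under either view.
assert (logical_num_blocks * logical_page_size_bytes
        == physical_num_blocks * physical_page_size)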
@@ -1533,6 +1629,21 @@ class NixlConnectorWorker:
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
if self.kv_topo.is_kv_layout_blocks_first:
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence we double the number
# of 'virtual' regions here and halve `block_len` below.
# Similarly for Mamba layers, we register SSM+Conv as a single region and
# then duplicate it logically to be able to index SSM/Conv separately.
self.num_regions *= 2
# TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to
# support heterogeneous TP.
self.num_descs = self.num_regions * self.num_blocks
descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
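The 'virtual' region doubling only changes the bookkeeping, not the registered memory; a tiny numeric sketch with assumed counts:

# Tiny sketch of the region/desc doubling for blocks-first layouts
# (counts are assumptions).
num_regions, num_blocks = 80, 1000
# K and V (or Conv and SSM) are indexed as separate 'virtual' regions
# over the same registered memory.
num_regions *= 2
num_descs = num_regions * num_blocks
assert (num_regions, num_descs) == (160, 160_000)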
@@ -1542,17 +1653,21 @@ class NixlConnectorWorker:
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first:
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence we double the number
# of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2
if self._has_mamba:
logger.info(
"Hybrid SSM registration: num_blocks=%s, "
"logical_num_blocks=%s, ratio=%s, num_regions=%s, "
"num_descs=%s, mamba_ssm_size=%s, block_len_per_layer=%s",
self.num_blocks,
self._logical_num_blocks,
self._physical_blocks_per_logical_kv_block,
self.num_regions,
self.num_descs,
self._mamba_ssm_size,
set(self.block_len_per_layer),
)
# Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses
self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
self.register_local_xfer_handler(self.block_size)
)
@@ -1569,6 +1684,7 @@ class NixlConnectorWorker:
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
block_size=self.block_size,
ssm_sizes=self._mamba_ssm_size,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
@@ -1594,40 +1710,65 @@ class NixlConnectorWorker:
data copy correctness.
"""
assert self.kv_topo is not None
kv_topo = self.kv_topo
block_size_ratio = self.block_size // block_size
blocks_data = []
for i, base_addr in enumerate(self.seen_base_addresses):
# The new block_len uses the prefill block_len;
# num_blocks is multiplied by the block_size ratio.
kv_block_len = (
self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
)
block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
num_blocks = self.num_blocks * block_size_ratio
for block_id in range(num_blocks):
block_offset = block_id * block_len_per_layer
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
blocks_data: list[tuple[int, int, int]] = []
local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank]
if self.kv_topo.is_kv_layout_blocks_first:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool):
for i, base_addr in enumerate(local_base_addresses):
# The new block_len uses the prefill block_len;
# num_blocks is multiplied by the block_size ratio.
kv_block_len = (
self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=True, mamba_view=mamba
)
// block_size_ratio
)
# Jump one page_size per block; the SSM page_size may be bigger when the
# kernel locks the block size to a specific value.
block_len_per_layer = (
self.block_len_per_layer[i]
// block_size_ratio
* (1 if not mamba else self._physical_blocks_per_logical_kv_block)
)
num_blocks = self._logical_num_blocks if mamba else self.num_blocks
num_blocks = num_blocks * block_size_ratio
for block_id in range(num_blocks):
block_offset = block_id * block_len_per_layer
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.device_id))
logger.debug(
"Created %s blocks for src engine %s and rank %s on device id %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
if kv_topo.is_kv_layout_blocks_first:
second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba
)
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(num_blocks):
block_offset = block_id * block_len_per_layer
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, second_split, self.device_id))
logger.debug(
"Created %s blocks for src engine %s and rank %s on device id %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)
register_blocks(blocks_data, mamba=False)
if self._has_mamba:
assert self.num_descs == len(blocks_data)
logger.debug(
"Registering additional %s local Mamba blocks", len(blocks_data)
)
register_blocks(blocks_data, mamba=True)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
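`register_blocks` emits all first-split descs of a region before its second-split descs; a compressed sketch of the (addr, len, device_id) construction, with assumed toy sizes:

# Sketch of the local desc construction for one blocks-first region
# (addresses and sizes are assumptions).
base_addr, device_id = 0x1000, 0
num_blocks, page_size = 3, 256
kv_block_len = page_size // 2  # first split: K (or Conv in the Mamba view)

blocks_data = []
for block_id in range(num_blocks):   # first-split descs (K / Conv)
    addr = base_addr + block_id * page_size
    blocks_data.append((addr, kv_block_len, device_id))
for block_id in range(num_blocks):   # second-split descs (V / SSM), kv_block_len further in
    addr = base_addr + block_id * page_size
    blocks_data.append((addr + kv_block_len, kv_block_len, device_id))

assert blocks_data[0] == (0x1000, 128, 0)
assert blocks_data[3] == (0x1080, 128, 0)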
@@ -1708,7 +1849,8 @@ class NixlConnectorWorker:
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
kv_topo = self.kv_topo
block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
@@ -1768,48 +1910,86 @@ class NixlConnectorWorker:
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
local_block_len = remote_kv_block_len
def register_remote_blocks(
blocks_data: list[tuple[int, int, int]], mamba: bool
):
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=True, mamba_view=mamba
)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len
if indexes_into_remote
else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_lens[i]
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len
if indexes_into_remote
else 0
)
if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_lens[i]
# Remote advertises FA (kernel) num_blocks; the Mamba view derives its
# logical count from the same physical/logical ratio.
num_blocks = (
nixl_agent_meta.num_blocks
if not mamba
else nixl_agent_meta.num_blocks
// self._physical_blocks_per_logical_kv_block
)
page_size = nixl_agent_meta.block_lens[i] * (
1 if not mamba else self._physical_blocks_per_logical_kv_block
)
for block_id in range(num_blocks):
block_offset = block_id * page_size
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
# (addr, len, device id)
blocks_data.append(
(v_addr, local_block_len, nixl_agent_meta.device_id)
(addr, local_block_len, nixl_agent_meta.device_id)
)
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
len(blocks_data),
engine_id,
remote_tp_rank,
self.tp_rank,
)
if kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba
)
# Apply the same scaling as local_block_len above for when we read
# a chunk of local V from `tp_ratio` separate remote workers.
if tp_ratio < 0 and not self.use_mla:
second_split = second_split // (-tp_ratio)
for block_id in range(num_blocks):
block_offset = block_id * page_size
addr = base_addr + block_offset + rank_offset
# Hop over the first split of remote page: either K or Conv.
if mamba:
v_addr = addr + nixl_agent_meta.ssm_sizes[0]
else:
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append(
(v_addr, second_split, nixl_agent_meta.device_id)
)
logger.debug(
"Created %s blocks for dst engine %s"
" with remote rank %s and local rank %s",
len(blocks_data),
engine_id,
remote_tp_rank,
self.tp_rank,
)
register_remote_blocks(blocks_data, mamba=False)
if self._has_mamba:
# Create extra descs for the Mamba "view" of the same KV cache tensors.
logger.debug(
"Registering additional %s remote Mamba blocks", len(blocks_data)
)
register_remote_blocks(blocks_data, mamba=True)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
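The only difference between the two remote views is where the second split starts within a page; a small arithmetic sketch with assumed sizes:

# Sketch of the second-split offset within a remote page (sizes are assumptions).
block_len = 4096                # remote attention page, joint K + V
ssm_sizes = (1024, 262144)      # remote (conv_bytes, ssm_bytes)
addr = 0x4000                   # start of some remote block

v_addr_attn = addr + block_len // 2   # attention view: V starts at half the page
v_addr_mamba = addr + ssm_sizes[0]    # Mamba view: SSM starts right after Conv
assert (v_addr_attn, v_addr_mamba) == (0x4800, 0x4400)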
@@ -1849,6 +2029,9 @@ class NixlConnectorWorker:
assert block_size_ratio == 1, (
"HMA does not support different remote block size yet"
)
# Mamba additional constraints
if self._has_mamba:
assert tp_ratio == 1, "Mamba does not support heterogeneous TP yet"
kv_cache_layout = (
self.kv_cache_layout
@@ -2495,6 +2678,7 @@ class NixlConnectorWorker:
A single flattened array is returned for all groups anyway.
"""
region_ids = np.arange(self.num_regions)
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
# layers from different groups share the same kv tensor.
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
@@ -2505,11 +2689,33 @@ class NixlConnectorWorker:
if block_size_ratio is not None:
num_blocks = int(num_blocks * block_size_ratio)
# Compute the desc ids for each block.
# Compute desc ids per group using the right stride: FA descs have
# num_blocks entries per region (kernel granularity), SSM descs have
# logical_blocks entries per region (no kernel splitting).
region_ids = region_ids[:, None]
block_ids = np.concatenate(block_ids)[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
if not self._has_mamba:
block_ids = np.concatenate(block_ids)[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
else:
# NOTE (NickLucche) SSM and Attention blocks regions can be exchanged
# arbitrarily by manager. Therefore, descs are duplicated for SSM and
# Attention like so:
# desc_handle->[descs_fa (all regions) | descs_ssm (all regions)].
# This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts).
ratio = self._physical_blocks_per_logical_kv_block
# SSM may register fewer num_blocks than FA
logical_blocks = num_blocks // ratio
num_fa_descs = self.num_regions * num_blocks
all_descs = []
for i, group in enumerate(block_ids):
stride = logical_blocks if self._is_mamba_group[i] else num_blocks
group_arr = np.asarray(group)[None, :]
offset = num_fa_descs if self._is_mamba_group[i] else 0
all_descs.append((region_ids * stride + group_arr + offset).flatten())
return np.concatenate(all_descs)
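A worked numpy example of the two-stride desc-id layout described above, with assumed counts (4 regions, 8 kernel blocks per region, ratio 2):

import numpy as np

# Worked sketch of the FA/SSM desc-id math (all counts are assumptions).
num_regions, num_blocks, ratio = 4, 8, 2
logical_blocks = num_blocks // ratio
num_fa_descs = num_regions * num_blocks       # FA descs occupy the first slots
region_ids = np.arange(num_regions)[:, None]

fa_block_ids = np.array([0, 1])[None, :]      # attention group, kernel granularity
ssm_block_ids = np.array([0])[None, :]        # mamba group, logical granularity
fa_descs = (region_ids * num_blocks + fa_block_ids).flatten()
ssm_descs = (region_ids * logical_blocks + ssm_block_ids + num_fa_descs).flatten()

assert fa_descs.tolist() == [0, 1, 8, 9, 16, 17, 24, 25]
assert ssm_descs.tolist() == [32, 36, 40, 44]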
def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
"""
@@ -2523,16 +2729,22 @@ class NixlConnectorWorker:
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
# Mamba blocks have no logical<>physical discrepancy
group_specs = self.kv_cache_config.kv_cache_groups
return [
BlockTable.map_to_kernel_blocks(
np.array(group),
self._physical_blocks_per_logical_kv_block,
block_arange,
).tolist()
for group in block_ids
if not isinstance(group_specs[i].kv_cache_spec, MambaSpec)
else group
for i, group in enumerate(block_ids)
]
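Assuming `BlockTable.map_to_kernel_blocks` expands each logical id into `ratio` consecutive kernel ids, the mapping for attention groups works out as in this small sketch (Mamba groups are returned untouched):

import numpy as np

# Sketch of the logical -> kernel block expansion for attention groups
# (assumes each logical block maps to `ratio` consecutive kernel blocks).
ratio = 2
block_arange = np.arange(ratio).reshape(1, -1)
logical_ids = np.array([3, 5])
kernel_ids = (logical_ids.reshape(-1, 1) * ratio + block_arange).flatten()
assert kernel_ids.tolist() == [6, 7, 10, 11]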
def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
def get_backend_aware_kv_block_len(
self, layer_idx: int, first_split: bool = True, mamba_view: bool = False
) -> int:
"""
Get the block length for one K/V element (K and V have the same size).
@@ -2540,11 +2752,38 @@ class NixlConnectorWorker:
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
Similarly, for SSM-based models, state and conv are interleaved, but crucially
their sizes differ.
Reference diagram:

            KVCacheTensor (Shared)
                    /      \
                 /            \
              /                  \
Attention (FlashInfer) View      Mamba View
          |                          |
          |                          |
+-------------------+     +--------------------+
|   KVCacheTensor   |     |   KVCacheTensor    |
|                   |     |                    |
|<----- page ------>|     |<----- page ------->|
|       size        |     |        size        |
| Key 0   | Val 0   |     | Conv 0   | SSM 0   |
| Key 1   | Val 1   |     | Conv 1   | SSM 1   |
|   ...   |   ...   |     |   ...    |   ...   |
| Key N-2 | Val N-2 |     | Conv N-2 | SSM N-2 |
| Key N-1 | Val N-1 |     | Conv N-1 | SSM N-1 |
+-------------------+     +--------------------+
|1st_split-2nd_split|     |1st_split-2nd_split |
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
block_len = self.block_len_per_layer[layer_idx] // 2
if mamba_view:
# NOTE (NickLucche) Mamba optimization: this already skips the padding, so
# we only transfer the minimum required bytes.
block_len = self._mamba_ssm_size[not first_split]
else:
block_len = self.block_len_per_layer[layer_idx] // 2
else:
block_len = self.block_len_per_layer[layer_idx]
return block_len
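A numeric sketch of the split lengths this returns, with assumed sizes for one attention layer and one Mamba layer:

# Numeric sketch of get_backend_aware_kv_block_len results (sizes are assumptions).
block_len_per_layer = [4096]        # attention layer, joint K/V page (blocks-first)
mamba_ssm_size = (1024, 262144)     # (conv_bytes, ssm_bytes)

# Attention view: each split is half the page.
assert block_len_per_layer[0] // 2 == 2048
# Mamba view: first split is Conv, second split is SSM.
assert mamba_ssm_size[not True] == 1024     # first_split=True  -> Conv
assert mamba_ssm_size[not False] == 262144  # first_split=False -> SSM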