diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index fe79a99fc..b0794bfa3 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -19,9 +19,9 @@ dp_ep_configs=( "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) ) hybrid_ssm_configs=( - "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code" + "VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code" # TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models. - "ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling" + "VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling" ) sw_attn_configs=( "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 898f8e4b3..adb0acae1 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm(): worker._has_mamba = True worker._is_mamba_group = [False, True] worker._physical_blocks_per_logical_kv_block = 1 + worker._mamba_phys_ratio = {engine_id: 1} + worker.block_len_per_layer = [100] # num_descs = num_regions * num_blocks (no blocks_first doubling) worker.num_descs = 2 * num_blocks @@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm(): # FA group: stride=num_blocks=100, offset=0 # region0: [3, 5], region1: [103, 105] # SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1), - # offset=num_descs=200 - # region0: [201, 202], region1: [301, 302] - expected = [3, 5, 103, 105, 201, 202, 301, 302] + # offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm) + # region0: [201, 202], region1: [301, 302], + # region2: [401, 402], region3: [501, 502] + expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502] assert list(result) == expected, f"Expected {expected}, got {list(result)}" @@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch(): worker._has_mamba = True worker._is_mamba_group = [False, True] worker._physical_blocks_per_logical_kv_block = ratio + worker._mamba_phys_ratio = {engine_id: ratio} + worker.block_len_per_layer = [100] worker.num_descs = 2 * num_blocks # 800 fa_blocks = [3, 7] # kernel-level block IDs @@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch(): # FA group: stride=num_blocks=400, offset=0 # region0: [3, 7], region1: [403, 407] - # SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800 - # region0: [801, 802], region1: [901, 902] - expected = [3, 7, 403, 407, 801, 802, 901, 902] + # SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800, + # 4 regions per Mamba layer (x, B, C, ssm) + # region0: [801, 802], region1: [901, 902], + # region2: [1001, 1002], region3: [1101, 1102] + expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102] assert list(result) == expected, f"Expected {expected}, got {list(result)}" @@ -418,3 +425,29 @@ def test_has_mamba_init( ) assert scheduler._has_mamba is expected_has_mamba assert scheduler._is_hma_required is expected_is_hma + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "ssm_sizes,block_len,expected_ratio", + [ + # Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261 + ((36864, 2097152), 8192, 261), + # Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261 + ((18432, 1048576), 4096, 261), + # Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131 + ((9216, 524288), 4096, 131), + ], +) +def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio): + """Verify that compute_mamba_phys_ratio is TP-dependent. + + With dimension-sharded Mamba state, the ratio differs across TP sizes + (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why + _mamba_phys_ratio must be stored per-engine. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + compute_mamba_phys_ratio, + ) + + assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 72980a85a..8e66fce4c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -5,7 +5,7 @@ KV cache helper for store. """ from collections.abc import Iterator -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, cast import torch @@ -516,6 +516,338 @@ class TpKVTopology: return cache if self.split_k_and_v else [cache] +# ---- Mamba-HMA hetero-TP transfer config ---- +# +# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be +# replicated across P ranks (when P_TP > num_kv_heads), but Mamba +# conv/SSM state is almost always uniquely sharded per P rank. So the +# number of P ranks D must read from can differ between FA and Mamba, +# and they must be handled separately. + + +def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: + """Physical KV head range stored in a rank's KV cache tensor. + + When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. + When ``tp_size > num_heads``: 1 physical head per rank. Heads are + distributed **contiguously** (matching vLLM's GQA weight partitioning): + consecutive ranks share a head before moving to the next one. + """ + if tp_size <= num_heads: + assert num_heads % tp_size == 0 + per_rank = num_heads // tp_size + return range(rank * per_rank, (rank + 1) * per_rank) + else: + h = rank * num_heads // tp_size + return range(h, h + 1) + + +def _range_overlap(a: range, b: range) -> range: + start = max(a.start, b.start) + stop = min(a.stop, b.stop) + return range(start, max(start, stop)) + + +@dataclass +class HeteroTPTransferConfig: + """Precomputed transfer plan for one (D rank, P engine) pair. + + Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models + where FA and mamba require different splitting factors. Could be extended + to other model types that need non-uniform hetero-TP transfer sizing. + + All descriptor sizes are computed here. The guarantee is: + local_entry_size == remote_entry_size (for NIXL) + + Attributes that start with ``fa_`` concern FlashAttention KV cache. + Attributes that start with ``mamba_`` concern Mamba conv/SSM state. + """ + + # ---- Input parameters (from handshake) ---- + tp_ratio: int + K: int # total_num_kv_heads (before TP sharding) + d_tp: int # D engine's tensor_parallel_size + p_tp: int # P engine's tensor_parallel_size + d_rank: int # this D worker's TP rank + use_mla: bool + + # Per-layer block lengths (bytes, K+V combined for blocks_first). + # Uniform across layers for current models. + d_block_len: int # D's block_len_per_layer (representative) + p_block_len: int # P's block_len_per_layer (from handshake) + is_blocks_first: bool # kv_topo.is_kv_layout_blocks_first + + # ---- Derived: computed in __post_init__ ---- + # + # Physical heads per rank (what the KV tensor actually stores) + d_physical_heads: int = field(init=False) + p_physical_heads: int = field(init=False) + + # How many distinct P ranks D needs for FA data + physical_fa_num_reads: int = field(init=False) + + # Which P ranks contribute unique FA heads (ordered by head index) + fa_read_targets: list[int] = field(init=False) + + # All P ranks needed for mamba (always abs_tp for tp_ratio < 0) + mamba_num_reads: int = field(init=False) + + # All P ranks this D rank communicates with (FA ∪ mamba) + transfer_targets: list[int] = field(init=False) + + # FA descriptor entry size (K or V side, for blocks_first layout) + # Guaranteed: fa_entry_size is the SAME for local handle AND remote desc. + fa_entry_size: int = field(init=False) + + # Replication flags + is_d_replicated: bool = field(init=False) + is_p_replicated: bool = field(init=False) + + # Pre-built set for fast lookup + _fa_target_set: frozenset[int] = field(init=False, repr=False) + # Map: P rank → index in fa_read_targets (for head slot offset) + _fa_target_index: dict[int, int] = field(init=False, repr=False) + + def __post_init__(self) -> None: + K = self.K + self.is_d_replicated = self.d_tp > K + self.is_p_replicated = self.p_tp > K + + self.d_physical_heads = max(1, K // self.d_tp) + self.p_physical_heads = max(1, K // self.p_tp) + + abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1 + + # ---- Mamba range (computed first so FA can prefer ranks in it) ---- + mamba_range: range | None = None + if self.tp_ratio < 0: + mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp) + + # ---- FA read targets ---- + if self.use_mla or self.tp_ratio >= 0: + self.physical_fa_num_reads = 1 + self.fa_read_targets = ( + [0] + if self.use_mla + # Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio). + else [ + self.d_rank // self.tp_ratio if self.tp_ratio > 0 else self.d_rank + ] + ) + else: + d_needs = _physical_head_range(self.d_tp, K, self.d_rank) + # When mamba range exists, prefer P ranks within it so that + # FA targets are a subset of mamba transfer_targets (avoids + # orphaned FA targets outside the transfer loop). + search_range = mamba_range if mamba_range is not None else range(self.p_tp) + seen: set[tuple[int, int]] = set() + targets: list[int] = [] + for p in search_range: + p_has = _physical_head_range(self.p_tp, K, p) + ov = _range_overlap(d_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + targets.append(p) + if not targets: + # Fallback: search globally (should not happen in practice) + for p in range(self.p_tp): + p_has = _physical_head_range(self.p_tp, K, p) + ov = _range_overlap(d_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + targets.append(p) + self.fa_read_targets = targets + self.physical_fa_num_reads = len(targets) + + self._fa_target_set = frozenset(self.fa_read_targets) + self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)} + + # ---- Mamba targets ---- + if mamba_range is not None and abs_tp > self.physical_fa_num_reads: + self.mamba_num_reads = abs_tp + self.transfer_targets = list(mamba_range) + else: + self.mamba_num_reads = self.physical_fa_num_reads + self.transfer_targets = list(self.fa_read_targets) + + # ---- FA entry size ---- + # For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V). + # Use min(D, P) because D indexes into P when tp_ratio > 0, + # and P is the natural unit when tp_ratio < 0. + effective_block_len = min(self.d_block_len, self.p_block_len) + if self.is_blocks_first: + self.fa_entry_size = effective_block_len // 2 + else: + self.fa_entry_size = effective_block_len + + self._validate() + + def _validate(self) -> None: + """Cross-check internal consistency.""" + if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0: + logger.info( + "Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. " + "Using d_rank // tp_ratio routing with relative head offset.", + self.d_tp, + self.p_tp, + self.K, + ) + + # FA targets must be a subset of transfer_targets + tt_set = set(self.transfer_targets) + for t in self.fa_read_targets: + if t not in tt_set: + logger.error( + "FA target P rank %d is NOT in transfer_targets %s. " + "This will cause missed FA reads!", + t, + self.transfer_targets, + ) + + # For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half + if ( + self.is_blocks_first + and self.tp_ratio < 0 + and self.physical_fa_num_reads > 0 + ): + d_k_half = self.d_block_len // 2 + p_k_half = self.p_block_len // 2 + expected_local = d_k_half // self.physical_fa_num_reads + if expected_local != p_k_half: + logger.warning( + "FA size mismatch: D_K_half=%d / reads=%d = %d, " + "but P_K_half=%d. This may indicate a head count or " + "Mamba-HMA inflation inconsistency.", + d_k_half, + self.physical_fa_num_reads, + expected_local, + p_k_half, + ) + + # ---- Query methods ---- + + def should_skip_fa(self, p_rank: int) -> bool: + """Whether to skip FA groups for this P rank (mamba-only transfer).""" + return p_rank not in self._fa_target_set + + def fa_head_slot(self, p_rank: int) -> int: + """Index into D's FA block for this P rank's head data. + + For P ranks in fa_read_targets, returns 0, 1, ..., reads-1. + For P ranks NOT in fa_read_targets (replicated duplicates), + returns the slot of the matching FA target with the same head. + """ + if p_rank in self._fa_target_index: + return self._fa_target_index[p_rank] + # Duplicate head: find which fa_target has the same physical head + p_head = _physical_head_range(self.p_tp, self.K, p_rank) + for target in self.fa_read_targets: + t_head = _physical_head_range(self.p_tp, self.K, target) + if _range_overlap(p_head, t_head): + return self._fa_target_index[target] + return 0 # fallback + + def fa_rank_offset(self, remote_kv_block_len: int) -> int: + """Byte offset into P's FA block for this D rank. + + When D is replicated (D_TP > K), multiple D ranks share a head. + Computes offset *relative to the target P rank's first head* + so it works regardless of how many heads P has. + When neither side replicates, falls back to tp_rank % tp_ratio. + Returns 0 when D does not index into P's block. + """ + if self.use_mla or self.tp_ratio <= 0: + return 0 + if self.is_d_replicated: + d_head = self.d_rank * self.K // self.d_tp + p_rank = self.fa_read_targets[0] + p_start = p_rank * self.K // self.p_tp + return (d_head - p_start) * remote_kv_block_len + return self.d_rank % self.tp_ratio * remote_kv_block_len + + @property + def needs_split_handles(self) -> bool: + """Whether per-P-rank split handles are needed. + + True when FA and mamba have different read counts, requiring + different splitting factors in the local handle. + """ + return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1 + + def compute_split_handle_data( + self, + src_blocks_data: list[tuple[int, int, int]], + num_fa_descs: int, + abs_tp: int, + ) -> list[list[tuple[int, int, int]]]: + """Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles. + + FA descriptors (indices < num_fa_descs) are sliced by + ``physical_fa_num_reads``; mamba descriptors are sliced uniformly + by ``abs_tp``. + + Returns one list of triples per transfer target. + """ + all_handle_data: list[list[tuple[int, int, int]]] = [] + for p_idx, p_rank in enumerate(self.transfer_targets): + handle_data: list[tuple[int, int, int]] = [] + skip_fa = self.should_skip_fa(p_rank) + fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0 + + for j, (addr, local_len, tp) in enumerate(src_blocks_data): + if j < num_fa_descs: + assert self.physical_fa_num_reads >= 1 + fa_chunk = local_len // self.physical_fa_num_reads + handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp)) + else: + mamba_chunk = local_len // abs_tp + handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp)) + all_handle_data.append(handle_data) + return all_handle_data + + def filter_block_ids_for_rank( + self, + remote_rank: int, + local_ids: BlockIds, + remote_ids: BlockIds, + is_mamba_group: list[bool], + ) -> tuple[BlockIds, BlockIds]: + """Zero out FA groups for P ranks outside fa_read_targets. + + Returns (filtered_local_ids, filtered_remote_ids). When the + remote rank carries FA data for this D rank, returns the inputs + unchanged. + """ + if not self.should_skip_fa(remote_rank): + return local_ids, remote_ids + num_groups = len(local_ids) + filtered_local: list[list[int]] = [ + [] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups) + ] + filtered_remote: list[list[int]] = [ + [] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups) + ] + return filtered_local, filtered_remote + + def describe(self) -> str: + """One-line summary for logging.""" + return ( + f"HeteroTPTransferConfig(" + f"tp_ratio={self.tp_ratio}, K={self.K}, " + f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, " + f"physical_fa_reads={self.physical_fa_num_reads}, " + f"mamba_reads={self.mamba_num_reads}, " + f"fa_targets={self.fa_read_targets}, " + f"transfer_targets={self.transfer_targets}, " + f"fa_entry_size={self.fa_entry_size}, " + f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})" + ) + + def get_current_attn_backends( vllm_config: VllmConfig, layer_names: list[str] | None = None ) -> list[type[AttentionBackend]]: @@ -559,3 +891,50 @@ def get_current_attn_backend( ) -> type[AttentionBackend]: """Get the first attention backend for the given layers.""" return get_current_attn_backends(vllm_config, layer_names)[0] + + +# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig +# into a single engine-agnostic TransferTopology class. +# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data. +# +# @dataclass +# class EngineTransferInfo: +# """Per-remote-engine transfer state, computed at handshake.""" +# p_tp: int +# tp_ratio: int +# p_block_len: int +# block_size: int +# # Mamba-specific (None for non-mamba models) +# fa_read_targets: list[int] | None = None +# transfer_targets: list[int] | None = None +# physical_fa_num_reads: int | None = None +# mamba_num_reads: int | None = None +# fa_entry_size: int | None = None +# +# class TransferTopology: +# """Single source of truth for TP topology + transfer sizing.""" +# # Shared (set once at init, replaces duplicate fields) +# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank +# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp +# total_num_kv_heads: int # == HeteroTP.K +# is_mla: bool # == HeteroTP.use_mla +# is_mamba: bool +# is_blocks_first: bool # == HeteroTP.is_blocks_first +# d_block_len: int +# +# # Per-engine (populated via register_engine() at handshake) +# _engines: dict[EngineId, EngineTransferInfo] +# +# def register_engine(self, engine_id, p_tp, p_block_len, ...): ... +# +# # General (from TpKVTopology) +# def tp_ratio(self, engine_id) -> int: ... +# def target_remote_ranks(self, engine_id) -> list[int]: ... +# def is_kv_replicated(self, engine_id) -> bool: ... +# +# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba) +# def fa_rank_offset(self, engine_id, block_len) -> int: ... +# def physical_fa_num_reads(self, engine_id) -> int: ... +# def transfer_targets(self, engine_id) -> list[int]: ... +# def should_skip_fa(self, engine_id, p_rank) -> bool: ... +# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ... diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0aaf3b6e9..c575043fb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -25,6 +25,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineId, + HeteroTPTransferConfig, TpKVTopology, get_current_attn_backend, get_current_attn_backends, @@ -47,12 +48,18 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( PromMetric, PromMetricT, ) +from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + MambaConvSplitInfo, + compute_mamba_phys_ratio, + derive_mamba_conv_split, +) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.network_utils import make_zmq_path, make_zmq_socket @@ -1038,7 +1045,7 @@ class NixlConnectorWorker: } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) - # Mamba metadata + # ---- Mamba model state (derived from model config) ---- self._is_mamba_group = [ isinstance(group.kv_cache_spec, MambaSpec) for group in kv_cache_config.kv_cache_groups @@ -1065,6 +1072,17 @@ class NixlConnectorWorker: ssm_shape.numel() * ssm_nbytes, ) self._mamba_ssm_size = mamba_ssm_size + # Conv state sub-projection decomposition (None when no Mamba). + # The 3-read transfer requires DS (dim, state_len) conv layout so + # that x/B/C sub-projections are contiguous in memory. + self._conv_decomp: MambaConvSplitInfo | None = None + if self._has_mamba: + assert is_conv_state_dim_first(), ( + "3-read Mamba conv transfer requires DS conv state layout. " + "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" + ) + local_tp = vllm_config.parallel_config.tensor_parallel_size + self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp) # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] @@ -1175,6 +1193,16 @@ class NixlConnectorWorker: self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] + # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- + # Per-engine transfer config (source of truth for FA/mamba sizing). + self._transfer_configs: dict[str, HeteroTPTransferConfig] = {} + # NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine. + # compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len) + # where conv/ssm bytes are per-TP-rank (dimension-sharded). With + # heterogeneous TP the per-rank sizes differ, so the ratio differs: + # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. + self._mamba_phys_ratio: dict[EngineId, int] = {} + # In progress transfers. # [req_id -> list[handle]] self._recving_metadata: dict[ReqId, ReqMeta] = {} @@ -1701,8 +1729,7 @@ class NixlConnectorWorker: # then duplicate it logically to be able to index SSM/Conv separately. self.num_regions *= 2 - # TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to - # support heterogeneous TP. + # Total local FA descriptors (boundary between FA and mamba descs). self.num_descs = self.num_regions * self.num_blocks descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) @@ -1715,6 +1742,9 @@ class NixlConnectorWorker: self.dst_num_blocks[self.engine_id] = self.num_blocks if self._has_mamba: + self._mamba_phys_ratio[self.engine_id] = ( + self._physical_blocks_per_logical_kv_block + ) logger.info( "Hybrid SSM registration: num_blocks=%s, " "logical_num_blocks=%s, ratio=%s, num_regions=%s, " @@ -1755,6 +1785,149 @@ class NixlConnectorWorker: agent_metadata_bytes=encoder.encode(agent_metadata), ) + def _build_mamba_local( + self, + base_addresses: list[int], + block_size_ratio: int, + ) -> list[tuple[int, int, int]]: + """Build 4 desc regions (x, B, C, ssm) per layer for local mamba + blocks, enabling the 3-read transfer with DS conv layout.""" + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " + f"Got block_size_ratio={block_size_ratio}." + ) + assert self._conv_decomp is not None + conv_offsets = self._conv_decomp.local_conv_offsets + conv_size, ssm_size = self._mamba_ssm_size + num_blocks = self._logical_num_blocks * block_size_ratio + phys_ratio = self._physical_blocks_per_logical_kv_block + + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate(base_addresses): + page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio + for off, sz in conv_offsets: + for blk in range(num_blocks): + result.append( + (base_addr + blk * page_stride + off, sz, self.device_id) + ) + # SSM temporal state follows the conv state. + for blk in range(num_blocks): + result.append( + ( + base_addr + blk * page_stride + conv_size, + ssm_size, + self.device_id, + ) + ) + return result + + def _build_fa_remote_for_mamba( + self, + nixl_agent_meta: NixlAgentMetadata, + transfer_cfg: HeteroTPTransferConfig, + block_size_ratio: int, + kv_topo: TpKVTopology, + ) -> list[tuple[int, int, int]]: + """Build remote FA descriptors for mamba models. + + Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset + instead of the standard uniform tp_ratio split. + """ + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " + f"Got block_size_ratio={block_size_ratio}." + ) + # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA + # hetero-TP logic stabilizes. + tp_ratio = transfer_cfg.tp_ratio + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + local_block_len = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=False + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + + if tp_ratio < 0 and not self.use_mla: + local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads + + rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len) + + num_blocks = nixl_agent_meta.num_blocks + page_size = nixl_agent_meta.block_lens[i] + for block_id in range(num_blocks): + block_offset = block_id * page_size + addr = base_addr + block_offset + rank_offset + result.append((addr, local_block_len, nixl_agent_meta.device_id)) + + if kv_topo.is_kv_layout_blocks_first: + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=False + ) + if tp_ratio < 0 and not self.use_mla: + second_split = second_split // transfer_cfg.physical_fa_num_reads + for block_id in range(num_blocks): + block_offset = block_id * page_size + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + result.append((v_addr, second_split, nixl_agent_meta.device_id)) + return result + + def _build_mamba_remote( + self, + nixl_agent_meta: NixlAgentMetadata, + tp_ratio: int, + ) -> list[tuple[int, int, int]]: + """Build 4 remote desc regions (x, B, C, ssm) per layer for + the 3-read transfer. For hetero-TP, each D rank reads only its + sub-projection slice from the P rank.""" + assert self._conv_decomp is not None + effective_ratio = max(tp_ratio, 1) + # Mamba conv state is always TP-sharded, even when attention KV + # is replicated (num_kv_heads < tp_size). + local_offset = self.tp_rank % effective_ratio + conv_size_remote = nixl_agent_meta.ssm_sizes[0] + + if tp_ratio >= 1: + # D_TP >= P_TP: P page is larger, D reads its slice. + conv_offsets = self._conv_decomp.remote_conv_offsets( + local_offset, effective_ratio + ) + ssm_read_size = self._mamba_ssm_size[1] + else: + # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages + # are smaller than D's. self._conv_decomp has D-sized dimensions, + # but we need P-sized offsets. Scale down by |tp_ratio|. + abs_ratio = -tp_ratio + xb_p = self._conv_decomp.x_bytes // abs_ratio + bb_p = self._conv_decomp.b_bytes // abs_ratio + conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)] + ssm_read_size = nixl_agent_meta.ssm_sizes[1] + + remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id] + num_blocks = nixl_agent_meta.num_blocks // remote_ratio + device_id = nixl_agent_meta.device_id + + result: list[tuple[int, int, int]] = [] + # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case + # block lengths vary across layers (e.g. MLA). + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + page_stride = nixl_agent_meta.block_lens[i] * remote_ratio + for off, sz in conv_offsets: + for blk in range(num_blocks): + result.append((base_addr + blk * page_stride + off, sz, device_id)) + # SSM temporal state is also TP-sharded on the heads dimension. + for blk in range(num_blocks): + ssm_addr = ( + base_addr + + blk * page_stride + + conv_size_remote + + local_offset * ssm_read_size + ) + result.append((ssm_addr, ssm_read_size, device_id)) + return result + def register_local_xfer_handler( self, block_size: int, @@ -1823,13 +1996,22 @@ class NixlConnectorWorker: self.device_id, ) + # NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used + # right now — we use _build_mamba_local instead for the 3-read + # approach. However, we might still need this as a fallback for homogeneous TP. register_blocks(blocks_data, mamba=False) if self._has_mamba: assert self.num_descs == len(blocks_data) - logger.debug( - "Registering additional %s local Mamba blocks", len(blocks_data) + # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is + # unnecessary — a single conv desc per block suffices. Consider + # adding a fast path that falls back to the standard 2-region + # registration (register_blocks mamba=True) when no hetero-TP + # remote has been seen. Currently we always register 4 regions + # because local descs are created before knowing the remote TP. + logger.debug("Registering local Mamba descriptors (4 regions/layer)") + blocks_data.extend( + self._build_mamba_local(local_base_addresses, block_size_ratio) ) - register_blocks(blocks_data, mamba=True) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. @@ -1880,6 +2062,9 @@ class NixlConnectorWorker: Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 so that the whole cache is shared by "tp_ratio" D TP workers. + + For Mamba hetero-TP, both tp_ratio > 0 (D_TP > P_TP) and + tp_ratio < 0 (P_TP > D_TP) are supported by the 3-read transfer. """ # noqa: E501 engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery @@ -1915,6 +2100,10 @@ class NixlConnectorWorker: if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + if self._has_mamba: + self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio( + nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0] + ) # Keep track of remote agent kv caches base addresses. self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( @@ -1931,6 +2120,21 @@ class NixlConnectorWorker: not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 ) + # Create transfer config (single source of truth for descriptor sizes). + if self._has_mamba and engine_id not in self._transfer_configs: + self._transfer_configs[engine_id] = HeteroTPTransferConfig( + tp_ratio=tp_ratio, + K=kv_topo.total_num_kv_heads, + d_tp=self.world_size, + p_tp=remote_tp_size, + d_rank=self.tp_rank, + use_mla=self.use_mla, + d_block_len=self.block_len_per_layer[0], + p_block_len=nixl_agent_meta.block_lens[0], + is_blocks_first=kv_topo.is_kv_layout_blocks_first, + ) + logger.info("Created %s", self._transfer_configs[engine_id].describe()) + logger.debug( "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", engine_id, @@ -1947,21 +2151,48 @@ class NixlConnectorWorker: # Remote tp_size > local tp_size: read from multiple remote ranks. # Logically "split" own regions into |tp_ratio| chunks. Mind that # we only do this once per remote tp_size (replica-friendly). + abs_tp = -tp_ratio self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for i in range(-tp_ratio): - blocks_data = [] - for memory_region in self.src_blocks_data: - addr, local_block_len, own_tp_rank = memory_region - # Computing block len layer by layer allows for different - # block sizes to be used. - remote_block_len = local_block_len // (-tp_ratio) - addr = addr + i * remote_block_len - blocks_data.append((addr, remote_block_len, own_tp_rank)) - descs = self.nixl_wrapper.get_xfer_descs( - blocks_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) + + if self._has_mamba: + transfer_cfg = self._transfer_configs.get(engine_id) + assert transfer_cfg is not None + if transfer_cfg.needs_split_handles: + # Mamba-HMA: FA and Mamba use different split factors. + for handle_data in transfer_cfg.compute_split_handle_data( + self.src_blocks_data, self.num_descs, abs_tp + ): + descs = self.nixl_wrapper.get_xfer_descs( + handle_data, self.nixl_memory_type + ) + handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs + ) + self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) + + logger.info( + "Mamba-HMA split handles: targets=%s, fa_reads=%s, " + "fa_entry=%s, mamba_reads=%s, num_descs=%s", + transfer_cfg.transfer_targets, + transfer_cfg.physical_fa_num_reads, + transfer_cfg.fa_entry_size, + transfer_cfg.mamba_num_reads, + self.num_descs, + ) + else: + # Original path: uniform divide by abs_tp (non-Mamba-HMA). + for i in range(abs_tp): + blocks_data = [] + for memory_region in self.src_blocks_data: + addr, local_block_len, own_tp_rank = memory_region + remote_block_len = local_block_len // abs_tp + addr = addr + i * remote_block_len + blocks_data.append((addr, remote_block_len, own_tp_rank)) + descs = self.nixl_wrapper.get_xfer_descs( + blocks_data, self.nixl_memory_type + ) + handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) ### Register remote agent memory regions blocks_data = [] @@ -2044,13 +2275,33 @@ class NixlConnectorWorker: self.tp_rank, ) - register_remote_blocks(blocks_data, mamba=False) if self._has_mamba: - # Create extra descs for the Mamba "view" of the same KV cache tensors. + # Mamba-HMA: separate FA registration with GQA-aware sizing, + # plus mamba 3-read registration for the Mamba "view" of the + # same KV cache tensors. logger.debug( - "Registering additional %s remote Mamba blocks", len(blocks_data) + "Registering remote Mamba blocks for engine %s rank %s", + engine_id, + remote_tp_rank, ) - register_remote_blocks(blocks_data, mamba=True) + transfer_cfg = self._transfer_configs.get(engine_id) + assert transfer_cfg is not None + blocks_data.extend( + self._build_fa_remote_for_mamba( + nixl_agent_meta, + transfer_cfg, + block_size_ratio, + kv_topo, + ) + ) + blocks_data.extend( + self._build_mamba_remote( + nixl_agent_meta, + tp_ratio, + ) + ) + else: + register_remote_blocks(blocks_data, mamba=False) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) @@ -2083,17 +2334,17 @@ class NixlConnectorWorker: block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( remote_engine_id ) - # Num kv_heads > tp_size and P TP > D TP case, not supported - assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) + # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. + # Mamba models can have replicated FA KV with tp_ratio < 0. + if not self._has_mamba: + assert not ( + tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id) + ) if self._is_hma_required: assert block_size_ratio == 1, ( "HMA does not support different remote block size yet" ) - # Mamba additional constraints - if self._has_mamba: - assert tp_ratio == 1, "Mamba does not support heterogeneous TP yet" - kv_cache_layout = ( self.kv_cache_layout if not self.use_host_buffer @@ -2138,11 +2389,14 @@ class NixlConnectorWorker: remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. - for i in range(len(self.block_len_per_layer)): - assert ( - self.block_len_per_layer[i] // block_size_ratio - == nixl_agent_meta.block_lens[i] - ), "KV cache sizes must match between P and D when replicated" + # TODO (ZhanqiuHu): For mamba models, validate FA and mamba + # block_lens separately. + if not self._has_mamba: + for i in range(len(self.block_len_per_layer)): + assert ( + self.block_len_per_layer[i] // block_size_ratio + == nixl_agent_meta.block_lens[i] + ), "KV cache sizes must match between P and D when replicated" else: # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: @@ -2150,25 +2404,31 @@ class NixlConnectorWorker: "All remote layers must have the same block size" ) - if tp_ratio > 0: - # Remote tp is smaller: remote block_len size is bigger - assert ( - remote_block_len - == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype." - ) # noqa: E501 - else: - assert block_size_ratio == 1, ( - "Different local/remote block sizes are not supported when" - " P TP > D TP." - ) - # Remote tp is bigger: remote block_len size is smaller - assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype." - ) # noqa: E501 + # HMA hybrid models (mamba+attention) pad block_len to + # max(attn_page, mamba_page), so the linear tp_ratio scaling + # assumption only holds for pure-attention models. + if not self._has_mamba: + if tp_ratio > 0: + assert ( + remote_block_len + == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio + ), ( + "Remote P worker KV layer cache must be of shape [2, N," + " local_kv_heads*tp_ratio, page_size, head_dim] and " + "same dtype." + ) + else: + assert block_size_ratio == 1, ( + "Different local/remote block sizes are not supported" + " when P TP > D TP." + ) + assert remote_block_len == self.block_len_per_layer[0] // ( + -tp_ratio + ), ( + "Remote P worker KV layer cache must be of shape [2, N," + " local_kv_heads/tp_ratio, page_size, head_dim] and " + "same dtype." + ) # TP workers that handhshake with same remote have same #blocks. assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks @@ -2471,9 +2731,8 @@ class NixlConnectorWorker: meta.local_block_ids ) assert meta.remote is not None - meta.remote.block_ids = self._logical_to_kernel_block_ids( - meta.remote.block_ids - ) + # Remote block IDs are kept logical here; expanded in + # _read_blocks_for_req using the remote engine's phys ratio. remote_engine_id = meta.remote.engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " @@ -2525,6 +2784,13 @@ class NixlConnectorWorker: meta.remote.engine_id ) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) + + if self._has_mamba: + # Expand remote logical → kernel block IDs. + meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( + meta.remote.block_ids, + self._mamba_phys_ratio[meta.remote.engine_id], + ) # D may have to perform multiple reads from different remote ranks. for i, remote_rank in enumerate(remote_ranks): if self.use_mla and tp_ratio < 0 and i > 0: @@ -2558,12 +2824,26 @@ class NixlConnectorWorker: remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ remote_rank ] + + local_ids: BlockIds = meta.local_physical_block_ids + remote_ids: BlockIds = meta.remote.block_ids + if self._has_mamba: + # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. + transfer_cfg = self._transfer_configs.get(meta.remote.engine_id) + assert transfer_cfg is not None + local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank( + remote_rank, + local_ids, + remote_ids, + self._is_mamba_group, + ) + self._read_blocks( request_id=req_id, dst_engine_id=meta.remote.engine_id, remote_request_id=meta.remote.request_id, - local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote.block_ids, + local_block_ids=local_ids, + remote_block_ids=remote_ids, remote_rank=remote_rank, local_xfer_side_handle=local_xfer_side_handle, remote_xfer_side_handle=remote_xfer_side_handle, @@ -2663,9 +2943,12 @@ class NixlConnectorWorker: for i, remote_group in enumerate(remote_block_ids): num_remote_blocks = len(remote_group) num_local_blocks = len(local_block_ids[i]) - assert num_local_blocks <= num_remote_blocks + if not self._is_mamba_group[i]: + assert num_local_blocks <= num_remote_blocks # Partial prefix cache hit: just read uncomputed blocks. - if num_local_blocks < num_remote_blocks: + # Skip mamba groups — their blocks represent full state (conv+ssm), + # not per-token data, so trimming would corrupt the transfer. + if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: remote_block_ids[i] = remote_group[-num_local_blocks:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -2781,16 +3064,22 @@ class NixlConnectorWorker: # This is like having two "low-level views" of the same storage. # `num_fa_descs` offset must be computed per-engine since P and D can # have different num_blocks (and thus different FA descs counts). - ratio = self._physical_blocks_per_logical_kv_block - # SSM may register fewer num_blocks than FA + ratio = self._mamba_phys_ratio[engine_id] logical_blocks = num_blocks // ratio num_fa_descs = self.num_regions * num_blocks + # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). + mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None] all_descs = [] for i, group in enumerate(block_ids): - stride = logical_blocks if self._is_mamba_group[i] else num_blocks group_arr = np.asarray(group)[None, :] - offset = num_fa_descs if self._is_mamba_group[i] else 0 - all_descs.append((region_ids * stride + group_arr + offset).flatten()) + if self._is_mamba_group[i]: + all_descs.append( + ( + mamba_region_ids * logical_blocks + group_arr + num_fa_descs + ).flatten() + ) + else: + all_descs.append((region_ids * num_blocks + group_arr).flatten()) return np.concatenate(all_descs) def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: @@ -2818,6 +3107,36 @@ class NixlConnectorWorker: for i, group in enumerate(block_ids) ] + def _logical_to_remote_kernel_block_ids( + self, block_ids: BlockIds, remote_ratio: int + ) -> BlockIds: + """Map logical block IDs to physical kernel block IDs on the remote. + + Args: + block_ids: per-group lists of logical block IDs. + remote_ratio: remote engine's physical blocks per logical block. + + Returns: + Same structure with FA groups expanded (each logical block L + becomes kernel blocks [L*remote_ratio .. L*remote_ratio + + local_ratio - 1]). Mamba groups are passed through unchanged. + """ + local_ratio = self._physical_blocks_per_logical_kv_block + if remote_ratio == 1: + return block_ids + local_arange = np.arange(local_ratio).reshape(1, -1) + group_specs = self.kv_cache_config.kv_cache_groups + result: list[list[int]] = [] + for i, group in enumerate(block_ids): + if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): + arr = np.array(group).reshape(-1, 1) + expanded = (arr * remote_ratio + local_arange).flatten() + result.append(expanded.tolist()) + else: + # Mamba blocks are 1:1 logical-to-physical (no expansion). + result.append(group) + return result + def get_backend_aware_kv_block_len( self, layer_idx: int, first_split: bool = True, mamba_view: bool = False ) -> int: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py new file mode 100644 index 000000000..6d65e006e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Mamba conv-state sub-projection decomposition for the 3-read transfer. + +With DS conv state layout (dim, state_len), x/B/C sub-projections are +contiguous in memory. Each D rank reads its x, B, C slices via 3 +separate RDMA transfers — no P-side permutation needed. +""" + +import math +from dataclasses import dataclass + +import torch + +from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first +from vllm.v1.kv_cache_interface import MambaSpec + + +@dataclass(frozen=True) +class MambaConvSplitInfo: + """Per-rank byte sizes of x, B, C sub-projections in the Mamba conv state. + + Used by both P and D sides for NIXL descriptor registration. + All fields are LOCAL to this engine's TP (already divided by TP size). + + DS memory layout within one page (contiguous in memory): + |--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -| + """ + + conv_rows: int # conv_kernel - 1 (typically 3) + x_local: int # intermediate_size / TP (columns for x) + b_local: int # groups_ss / TP (columns for B; C is same size) + conv_dtype_size: int # bytes per element (e.g. 2 for float16) + + @property + def conv_dim_local(self) -> int: + """Total conv columns per rank: x + B + C.""" + return self.x_local + 2 * self.b_local + + @property + def x_bytes(self) -> int: + """Byte size of the x sub-projection for one rank.""" + return self.x_local * self.conv_rows * self.conv_dtype_size + + @property + def b_bytes(self) -> int: + """Byte size of the B (or C) sub-projection for one rank.""" + return self.b_local * self.conv_rows * self.conv_dtype_size + + @property + def local_conv_offsets(self) -> list[tuple[int, int]]: + """(byte_offset, byte_size) of x, B, C within this engine's page. + + Used by both P and D for local descriptor registration. + """ + xb = self.x_bytes + bb = self.b_bytes + return [(0, xb), (xb, bb), (xb + bb, bb)] + + def remote_conv_offsets( + self, local_rank_offset: int, tp_ratio: int + ) -> list[tuple[int, int]]: + """(byte_offset, byte_size) of this D rank's x, B, C slice within + one P page. + + Used by D side only, during remote descriptor registration. + + Args: + local_rank_offset: which slice this D rank reads. + tp_ratio > 0: tp_rank % tp_ratio (selects slice of P's page). + tp_ratio < 0: always 0 (read P's full page). + tp_ratio: effective ratio (>= 1 when D_TP > P_TP, 1 when + P_TP > D_TP since each P rank is read in full). + """ + xb = self.x_bytes + bb = self.b_bytes + xr = xb * tp_ratio # full remote x section in bytes + br = bb * tp_ratio # full remote B section in bytes + return [ + (local_rank_offset * xb, xb), + (xr + local_rank_offset * bb, bb), + (xr + br + local_rank_offset * bb, bb), + ] + + +def derive_mamba_conv_split( + mamba_spec: MambaSpec, + local_tp: int, +) -> MambaConvSplitInfo: + """Derive per-rank x/B/C byte sizes from a MambaSpec. + + Called once at init on both P and D. Decomposes the conv dimension + (= intermediate_size + 2 * groups_ss) into its x, B, C parts. + + Args: + mamba_spec: MambaSpec whose shapes are: + shapes[0] = conv state: (conv_dim_local, conv_rows) in DS layout. + shapes[1] = SSM temporal: (local_num_heads, head_dim). + local_tp: this engine's tensor-parallel size. + + Returns: + MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and + conv_dtype_size. + """ + if mamba_spec.mamba_type != "mamba2": + raise NotImplementedError( + f"3-read conv transfer only supports Mamba2 models, " + f"got mamba_type={mamba_spec.mamba_type!r}. " + f"Mamba1 SSM temporal shape is (intermediate_size // tp, state_size) " + f"which cannot be used to reconstruct intermediate_size." + ) + + conv_shape = mamba_spec.shapes[0] + assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}" + + # NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted + # in nixl_connector __init__. Use it directly instead of heuristic detection. + assert is_conv_state_dim_first(), "3-read requires DS conv state layout" + local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows) + conv_rows = conv_shape[1] + + # NOTE (ZhanqiuHu): intermediate_size (= global x dim) is not stored + # in MambaSpec, so we reconstruct it from the SSM temporal state shape: + # shapes[1] = (local_num_heads, head_dim), already divided by TP. + head_dim = mamba_spec.shapes[1][1] + local_num_heads = mamba_spec.shapes[1][0] + intermediate_size = local_num_heads * local_tp * head_dim + + # NOTE (ZhanqiuHu): global conv dim = intermediate_size + 2 * groups_ss, + # where groups_ss is the B (= C) dimension. B and C are always the same + # size, so we recover groups_ss from the remainder after subtracting x. + remainder = local_conv_dim * local_tp - intermediate_size + assert remainder > 0 and remainder % 2 == 0, ( + f"Conv dim ({local_conv_dim}*tp={local_tp}) doesn't decompose into " + f"intermediate_size={intermediate_size} + 2*groups_ss. " + f"remainder={remainder}" + ) + groups_ss = remainder // 2 + + conv_dtype_size = torch.tensor( + [], + dtype=mamba_spec.dtypes[0], # type: ignore[misc] + ).element_size() + + # Divide by TP to get per-rank column counts. + return MambaConvSplitInfo( + conv_rows=conv_rows, + x_local=intermediate_size // local_tp, + b_local=groups_ss // local_tp, + conv_dtype_size=conv_dtype_size, + ) + + +def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int: + """Derive _physical_blocks_per_logical_kv_block from remote metadata. + + The remote engine's ratio is not sent directly in the handshake, so we + reconstruct it: total mamba state per logical block / block_len. + + Args: + ssm_sizes: (conv_state_bytes, ssm_state_bytes) from NixlAgentMetadata. + block_len: the engine's block_len in bytes (from block_lens[0]). + """ + return math.ceil((ssm_sizes[0] + ssm_sizes[1]) / block_len)