[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)

This commit is contained in:
zhanqiuhu
2026-04-06 13:07:02 -04:00
committed by GitHub
parent 93bada494f
commit bfdc0a3a99
5 changed files with 970 additions and 75 deletions

View File

@@ -19,9 +19,9 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
sw_attn_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"

View File

@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# region0: [201, 202], region1: [301, 302]
expected = [3, 5, 103, 105, 201, 202, 301, 302]
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
# region0: [201, 202], region1: [301, 302],
# region2: [401, 402], region3: [501, 502]
expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
fa_blocks = [3, 7] # kernel-level block IDs
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# region0: [801, 802], region1: [901, 902]
expected = [3, 7, 403, 407, 801, 802, 901, 902]
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
# 4 regions per Mamba layer (x, B, C, ssm)
# region0: [801, 802], region1: [901, 902],
# region2: [1001, 1002], region3: [1101, 1102]
expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
@@ -418,3 +425,29 @@ def test_has_mamba_init(
)
assert scheduler._has_mamba is expected_has_mamba
assert scheduler._is_hma_required is expected_is_hma
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "ssm_sizes,block_len,expected_ratio",
    [
        # Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
        ((36864, 2097152), 8192, 261),
        # Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
        ((18432, 1048576), 4096, 261),
        # Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
        ((9216, 524288), 4096, 131),
    ],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
    """Verify that compute_mamba_phys_ratio is TP-dependent.

    With dimension-sharded Mamba state, the ratio differs across TP sizes
    (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
    _mamba_phys_ratio must be stored per-engine.
    """
    # Imported lazily so collecting this test file does not require the
    # connector's heavier transitive imports.
    from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
        compute_mamba_phys_ratio,
    )

    assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio

View File

@@ -5,7 +5,7 @@ KV cache helper for store.
"""
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast
import torch
@@ -516,6 +516,338 @@ class TpKVTopology:
return cache if self.split_k_and_v else [cache]
# ---- Mamba-HMA hetero-TP transfer config ----
#
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
# conv/SSM state is almost always uniquely sharded per P rank. So the
# number of P ranks D must read from can differ between FA and Mamba,
# and they must be handled separately.
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
"""Physical KV head range stored in a rank's KV cache tensor.
When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank.
When ``tp_size > num_heads``: 1 physical head per rank. Heads are
distributed **contiguously** (matching vLLM's GQA weight partitioning):
consecutive ranks share a head before moving to the next one.
"""
if tp_size <= num_heads:
assert num_heads % tp_size == 0
per_rank = num_heads // tp_size
return range(rank * per_rank, (rank + 1) * per_rank)
else:
h = rank * num_heads // tp_size
return range(h, h + 1)
def _range_overlap(a: range, b: range) -> range:
start = max(a.start, b.start)
stop = min(a.stop, b.stop)
return range(start, max(start, stop))
@dataclass
class HeteroTPTransferConfig:
    """Precomputed transfer plan for one (D rank, P engine) pair.

    Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
    where FA and mamba require different splitting factors. Could be extended
    to other model types that need non-uniform hetero-TP transfer sizing.

    All descriptor sizes are computed here. The guarantee is:
        local_entry_size == remote_entry_size (for NIXL)

    Attributes that start with ``fa_`` concern FlashAttention KV cache.
    Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
    """

    # ---- Input parameters (from handshake) ----
    tp_ratio: int
    K: int  # total_num_kv_heads (before TP sharding)
    d_tp: int  # D engine's tensor_parallel_size
    p_tp: int  # P engine's tensor_parallel_size
    d_rank: int  # this D worker's TP rank
    use_mla: bool
    # Per-layer block lengths (bytes, K+V combined for blocks_first).
    # Uniform across layers for current models.
    d_block_len: int  # D's block_len_per_layer (representative)
    p_block_len: int  # P's block_len_per_layer (from handshake)
    is_blocks_first: bool  # kv_topo.is_kv_layout_blocks_first

    # ---- Derived: computed in __post_init__ ----
    #
    # Physical heads per rank (what the KV tensor actually stores)
    d_physical_heads: int = field(init=False)
    p_physical_heads: int = field(init=False)
    # How many distinct P ranks D needs for FA data
    physical_fa_num_reads: int = field(init=False)
    # Which P ranks contribute unique FA heads (ordered by head index)
    fa_read_targets: list[int] = field(init=False)
    # All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
    mamba_num_reads: int = field(init=False)
    # All P ranks this D rank communicates with (union of FA and mamba)
    transfer_targets: list[int] = field(init=False)
    # FA descriptor entry size (K or V side, for blocks_first layout)
    # Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
    fa_entry_size: int = field(init=False)
    # Replication flags
    is_d_replicated: bool = field(init=False)
    is_p_replicated: bool = field(init=False)
    # Pre-built set for fast lookup
    _fa_target_set: frozenset[int] = field(init=False, repr=False)
    # Map: P rank → index in fa_read_targets (for head slot offset)
    _fa_target_index: dict[int, int] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        # Derive replication flags, FA/mamba read targets, and descriptor
        # sizing from the raw handshake parameters, then self-check.
        K = self.K
        self.is_d_replicated = self.d_tp > K
        self.is_p_replicated = self.p_tp > K
        self.d_physical_heads = max(1, K // self.d_tp)
        self.p_physical_heads = max(1, K // self.p_tp)
        # abs_tp is the number of P ranks per D rank when P_TP > D_TP
        # (tp_ratio < 0), otherwise 1.
        abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1
        # ---- Mamba range (computed first so FA can prefer ranks in it) ----
        mamba_range: range | None = None
        if self.tp_ratio < 0:
            mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp)
        # ---- FA read targets ----
        if self.use_mla or self.tp_ratio >= 0:
            self.physical_fa_num_reads = 1
            self.fa_read_targets = (
                [0]
                if self.use_mla
                # Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
                else [
                    self.d_rank // self.tp_ratio if self.tp_ratio > 0 else self.d_rank
                ]
            )
        else:
            d_needs = _physical_head_range(self.d_tp, K, self.d_rank)
            # When mamba range exists, prefer P ranks within it so that
            # FA targets are a subset of mamba transfer_targets (avoids
            # orphaned FA targets outside the transfer loop).
            search_range = mamba_range if mamba_range is not None else range(self.p_tp)
            # Deduplicate by head-range key so replicated P ranks that
            # store the same head contribute only one read target.
            seen: set[tuple[int, int]] = set()
            targets: list[int] = []
            for p in search_range:
                p_has = _physical_head_range(self.p_tp, K, p)
                ov = _range_overlap(d_needs, p_has)
                if len(ov) > 0:
                    key = (ov.start, ov.stop)
                    if key not in seen:
                        seen.add(key)
                        targets.append(p)
            if not targets:
                # Fallback: search globally (should not happen in practice)
                for p in range(self.p_tp):
                    p_has = _physical_head_range(self.p_tp, K, p)
                    ov = _range_overlap(d_needs, p_has)
                    if len(ov) > 0:
                        key = (ov.start, ov.stop)
                        if key not in seen:
                            seen.add(key)
                            targets.append(p)
            self.fa_read_targets = targets
            self.physical_fa_num_reads = len(targets)
        self._fa_target_set = frozenset(self.fa_read_targets)
        self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)}
        # ---- Mamba targets ----
        # Mamba state is uniquely sharded per P rank, so D must read from
        # all abs_tp ranks even when FA needs fewer (replicated FA heads).
        if mamba_range is not None and abs_tp > self.physical_fa_num_reads:
            self.mamba_num_reads = abs_tp
            self.transfer_targets = list(mamba_range)
        else:
            self.mamba_num_reads = self.physical_fa_num_reads
            self.transfer_targets = list(self.fa_read_targets)
        # ---- FA entry size ----
        # For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
        # Use min(D, P) because D indexes into P when tp_ratio > 0,
        # and P is the natural unit when tp_ratio < 0.
        effective_block_len = min(self.d_block_len, self.p_block_len)
        if self.is_blocks_first:
            self.fa_entry_size = effective_block_len // 2
        else:
            self.fa_entry_size = effective_block_len
        self._validate()

    def _validate(self) -> None:
        """Cross-check internal consistency (logs, never raises)."""
        if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0:
            logger.info(
                "Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
                "Using d_rank // tp_ratio routing with relative head offset.",
                self.d_tp,
                self.p_tp,
                self.K,
            )
        # FA targets must be a subset of transfer_targets
        tt_set = set(self.transfer_targets)
        for t in self.fa_read_targets:
            if t not in tt_set:
                logger.error(
                    "FA target P rank %d is NOT in transfer_targets %s. "
                    "This will cause missed FA reads!",
                    t,
                    self.transfer_targets,
                )
        # For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
        if (
            self.is_blocks_first
            and self.tp_ratio < 0
            and self.physical_fa_num_reads > 0
        ):
            d_k_half = self.d_block_len // 2
            p_k_half = self.p_block_len // 2
            expected_local = d_k_half // self.physical_fa_num_reads
            if expected_local != p_k_half:
                logger.warning(
                    "FA size mismatch: D_K_half=%d / reads=%d = %d, "
                    "but P_K_half=%d. This may indicate a head count or "
                    "Mamba-HMA inflation inconsistency.",
                    d_k_half,
                    self.physical_fa_num_reads,
                    expected_local,
                    p_k_half,
                )

    # ---- Query methods ----
    def should_skip_fa(self, p_rank: int) -> bool:
        """Whether to skip FA groups for this P rank (mamba-only transfer)."""
        return p_rank not in self._fa_target_set

    def fa_head_slot(self, p_rank: int) -> int:
        """Index into D's FA block for this P rank's head data.

        For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
        For P ranks NOT in fa_read_targets (replicated duplicates),
        returns the slot of the matching FA target with the same head.
        """
        if p_rank in self._fa_target_index:
            return self._fa_target_index[p_rank]
        # Duplicate head: find which fa_target has the same physical head.
        # Note: an empty overlap range is falsy, so disjoint heads skip.
        p_head = _physical_head_range(self.p_tp, self.K, p_rank)
        for target in self.fa_read_targets:
            t_head = _physical_head_range(self.p_tp, self.K, target)
            if _range_overlap(p_head, t_head):
                return self._fa_target_index[target]
        return 0  # fallback

    def fa_rank_offset(self, remote_kv_block_len: int) -> int:
        """Byte offset into P's FA block for this D rank.

        When D is replicated (D_TP > K), multiple D ranks share a head.
        Computes offset *relative to the target P rank's first head*
        so it works regardless of how many heads P has.
        When neither side replicates, falls back to tp_rank % tp_ratio.
        Returns 0 when D does not index into P's block.
        """
        if self.use_mla or self.tp_ratio <= 0:
            return 0
        if self.is_d_replicated:
            d_head = self.d_rank * self.K // self.d_tp
            p_rank = self.fa_read_targets[0]
            p_start = p_rank * self.K // self.p_tp
            return (d_head - p_start) * remote_kv_block_len
        return self.d_rank % self.tp_ratio * remote_kv_block_len

    @property
    def needs_split_handles(self) -> bool:
        """Whether per-P-rank split handles are needed.

        True when FA and mamba have different read counts, requiring
        different splitting factors in the local handle.
        """
        return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1

    def compute_split_handle_data(
        self,
        src_blocks_data: list[tuple[int, int, int]],
        num_fa_descs: int,
        abs_tp: int,
    ) -> list[list[tuple[int, int, int]]]:
        """Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles.

        FA descriptors (indices < num_fa_descs) are sliced by
        ``physical_fa_num_reads``; mamba descriptors are sliced uniformly
        by ``abs_tp``.
        Returns one list of triples per transfer target.
        """
        all_handle_data: list[list[tuple[int, int, int]]] = []
        for p_idx, p_rank in enumerate(self.transfer_targets):
            handle_data: list[tuple[int, int, int]] = []
            skip_fa = self.should_skip_fa(p_rank)
            fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0
            for j, (addr, local_len, tp) in enumerate(src_blocks_data):
                if j < num_fa_descs:
                    assert self.physical_fa_num_reads >= 1
                    fa_chunk = local_len // self.physical_fa_num_reads
                    handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp))
                else:
                    mamba_chunk = local_len // abs_tp
                    handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp))
            all_handle_data.append(handle_data)
        return all_handle_data

    def filter_block_ids_for_rank(
        self,
        remote_rank: int,
        local_ids: BlockIds,
        remote_ids: BlockIds,
        is_mamba_group: list[bool],
    ) -> tuple[BlockIds, BlockIds]:
        """Zero out FA groups for P ranks outside fa_read_targets.

        Returns (filtered_local_ids, filtered_remote_ids). When the
        remote rank carries FA data for this D rank, returns the inputs
        unchanged.
        """
        if not self.should_skip_fa(remote_rank):
            return local_ids, remote_ids
        num_groups = len(local_ids)
        filtered_local: list[list[int]] = [
            [] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups)
        ]
        filtered_remote: list[list[int]] = [
            [] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups)
        ]
        return filtered_local, filtered_remote

    def describe(self) -> str:
        """One-line summary for logging."""
        return (
            f"HeteroTPTransferConfig("
            f"tp_ratio={self.tp_ratio}, K={self.K}, "
            f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, "
            f"physical_fa_reads={self.physical_fa_num_reads}, "
            f"mamba_reads={self.mamba_num_reads}, "
            f"fa_targets={self.fa_read_targets}, "
            f"transfer_targets={self.transfer_targets}, "
            f"fa_entry_size={self.fa_entry_size}, "
            f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})"
        )
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
@@ -559,3 +891,50 @@ def get_current_attn_backend(
) -> type[AttentionBackend]:
"""Get the first attention backend for the given layers."""
return get_current_attn_backends(vllm_config, layer_names)[0]
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig
# into a single engine-agnostic TransferTopology class.
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data.
#
# @dataclass
# class EngineTransferInfo:
# """Per-remote-engine transfer state, computed at handshake."""
# p_tp: int
# tp_ratio: int
# p_block_len: int
# block_size: int
# # Mamba-specific (None for non-mamba models)
# fa_read_targets: list[int] | None = None
# transfer_targets: list[int] | None = None
# physical_fa_num_reads: int | None = None
# mamba_num_reads: int | None = None
# fa_entry_size: int | None = None
#
# class TransferTopology:
# """Single source of truth for TP topology + transfer sizing."""
# # Shared (set once at init, replaces duplicate fields)
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
# total_num_kv_heads: int # == HeteroTP.K
# is_mla: bool # == HeteroTP.use_mla
# is_mamba: bool
# is_blocks_first: bool # == HeteroTP.is_blocks_first
# d_block_len: int
#
# # Per-engine (populated via register_engine() at handshake)
# _engines: dict[EngineId, EngineTransferInfo]
#
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
#
# # General (from TpKVTopology)
# def tp_ratio(self, engine_id) -> int: ...
# def target_remote_ranks(self, engine_id) -> list[int]: ...
# def is_kv_replicated(self, engine_id) -> bool: ...
#
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
# def physical_fa_num_reads(self, engine_id) -> int: ...
# def transfer_targets(self, engine_id) -> list[int]: ...
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...

View File

@@ -25,6 +25,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
BlockIds,
EngineId,
HeteroTPTransferConfig,
TpKVTopology,
get_current_attn_backend,
get_current_attn_backends,
@@ -47,12 +48,18 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
PromMetric,
PromMetricT,
)
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
MambaConvSplitInfo,
compute_mamba_phys_ratio,
derive_mamba_conv_split,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
@@ -1038,7 +1045,7 @@ class NixlConnectorWorker:
}
self.hma_group_size = len(kv_cache_config.kv_cache_tensors)
# Mamba metadata
# ---- Mamba model state (derived from model config) ----
self._is_mamba_group = [
isinstance(group.kv_cache_spec, MambaSpec)
for group in kv_cache_config.kv_cache_groups
@@ -1065,6 +1072,17 @@ class NixlConnectorWorker:
ssm_shape.numel() * ssm_nbytes,
)
self._mamba_ssm_size = mamba_ssm_size
# Conv state sub-projection decomposition (None when no Mamba).
# The 3-read transfer requires DS (dim, state_len) conv layout so
# that x/B/C sub-projections are contiguous in memory.
self._conv_decomp: MambaConvSplitInfo | None = None
if self._has_mamba:
assert is_conv_state_dim_first(), (
"3-read Mamba conv transfer requires DS conv state layout. "
"Set VLLM_SSM_CONV_STATE_LAYOUT=DS"
)
local_tp = vllm_config.parallel_config.tensor_parallel_size
self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp)
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
@@ -1175,6 +1193,16 @@ class NixlConnectorWorker:
self.dst_num_blocks: dict[EngineId, int] = {}
self._registered_descs: list[Any] = []
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# Per-engine transfer config (source of truth for FA/mamba sizing).
self._transfer_configs: dict[str, HeteroTPTransferConfig] = {}
# NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine.
# compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
self._mamba_phys_ratio: dict[EngineId, int] = {}
# In progress transfers.
# [req_id -> list[handle]]
self._recving_metadata: dict[ReqId, ReqMeta] = {}
@@ -1701,8 +1729,7 @@ class NixlConnectorWorker:
# then duplicate it logically to be able to index SSM/Conv separately.
self.num_regions *= 2
# TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to
# support heterogeneous TP.
# Total local FA descriptors (boundary between FA and mamba descs).
self.num_descs = self.num_regions * self.num_blocks
descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type)
@@ -1715,6 +1742,9 @@ class NixlConnectorWorker:
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._has_mamba:
self._mamba_phys_ratio[self.engine_id] = (
self._physical_blocks_per_logical_kv_block
)
logger.info(
"Hybrid SSM registration: num_blocks=%s, "
"logical_num_blocks=%s, ratio=%s, num_regions=%s, "
@@ -1755,6 +1785,149 @@ class NixlConnectorWorker:
agent_metadata_bytes=encoder.encode(agent_metadata),
)
def _build_mamba_local(
    self,
    base_addresses: list[int],
    block_size_ratio: int,
) -> list[tuple[int, int, int]]:
    """Build 4 desc regions (x, B, C, ssm) per layer for local mamba
    blocks, enabling the 3-read transfer with DS conv layout.

    Descriptors are emitted region-major: for each layer, all blocks of
    the x sub-projection, then all blocks of B, then C, then the SSM
    temporal state. Each triple is (address, size, device_id).
    """
    assert block_size_ratio == 1, (
        "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
        f"Got block_size_ratio={block_size_ratio}."
    )
    assert self._conv_decomp is not None
    # (offset, size) pairs for the x/B/C slices of the conv state page.
    conv_offsets = self._conv_decomp.local_conv_offsets
    conv_size, ssm_size = self._mamba_ssm_size
    num_blocks = self._logical_num_blocks * block_size_ratio
    phys_ratio = self._physical_blocks_per_logical_kv_block
    result: list[tuple[int, int, int]] = []
    for i, base_addr in enumerate(base_addresses):
        # One logical mamba block spans phys_ratio physical KV blocks.
        page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio
        for off, sz in conv_offsets:
            for blk in range(num_blocks):
                result.append(
                    (base_addr + blk * page_stride + off, sz, self.device_id)
                )
        # SSM temporal state follows the conv state.
        for blk in range(num_blocks):
            result.append(
                (
                    base_addr + blk * page_stride + conv_size,
                    ssm_size,
                    self.device_id,
                )
            )
    return result
def _build_fa_remote_for_mamba(
    self,
    nixl_agent_meta: NixlAgentMetadata,
    transfer_cfg: HeteroTPTransferConfig,
    block_size_ratio: int,
    kv_topo: TpKVTopology,
) -> list[tuple[int, int, int]]:
    """Build remote FA descriptors for mamba models.

    Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset
    instead of the standard uniform tp_ratio split. For blocks_first
    layouts a second pass appends the V-side descriptors, offset by half
    a remote block. Each triple is (address, size, device_id).
    """
    assert block_size_ratio == 1, (
        "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
        f"Got block_size_ratio={block_size_ratio}."
    )
    # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
    # hetero-TP logic stabilizes.
    tp_ratio = transfer_cfg.tp_ratio
    result: list[tuple[int, int, int]] = []
    for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
        local_block_len = self.get_backend_aware_kv_block_len(
            layer_idx=i, first_split=True, mamba_view=False
        )
        remote_kv_block_len = local_block_len // block_size_ratio
        if block_size_ratio > 1:
            # Unreachable today given the assert above; kept for when
            # block_size_ratio > 1 support lands.
            local_block_len = remote_kv_block_len
        if tp_ratio < 0 and not self.use_mla:
            # P_TP > D_TP: each read fetches only this D rank's share.
            local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads
        # Head-relative byte offset into the remote block (0 for MLA or
        # tp_ratio <= 0 — see HeteroTPTransferConfig.fa_rank_offset).
        rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len)
        num_blocks = nixl_agent_meta.num_blocks
        page_size = nixl_agent_meta.block_lens[i]
        for block_id in range(num_blocks):
            block_offset = block_id * page_size
            addr = base_addr + block_offset + rank_offset
            result.append((addr, local_block_len, nixl_agent_meta.device_id))
        if kv_topo.is_kv_layout_blocks_first:
            # V-side descriptors: same addressing, shifted by half a block.
            second_split = self.get_backend_aware_kv_block_len(
                layer_idx=i, first_split=False, mamba_view=False
            )
            if tp_ratio < 0 and not self.use_mla:
                second_split = second_split // transfer_cfg.physical_fa_num_reads
            for block_id in range(num_blocks):
                block_offset = block_id * page_size
                addr = base_addr + block_offset + rank_offset
                v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                result.append((v_addr, second_split, nixl_agent_meta.device_id))
    return result
def _build_mamba_remote(
    self,
    nixl_agent_meta: NixlAgentMetadata,
    tp_ratio: int,
) -> list[tuple[int, int, int]]:
    """Build 4 remote desc regions (x, B, C, ssm) per layer for
    the 3-read transfer. For hetero-TP, each D rank reads only its
    sub-projection slice from the P rank.

    Region-major ordering mirrors _build_mamba_local so local and remote
    descriptor lists pair up index-for-index.
    """
    assert self._conv_decomp is not None
    effective_ratio = max(tp_ratio, 1)
    # Mamba conv state is always TP-sharded, even when attention KV
    # is replicated (num_kv_heads < tp_size).
    local_offset = self.tp_rank % effective_ratio
    conv_size_remote = nixl_agent_meta.ssm_sizes[0]
    if tp_ratio >= 1:
        # D_TP >= P_TP: P page is larger, D reads its slice.
        conv_offsets = self._conv_decomp.remote_conv_offsets(
            local_offset, effective_ratio
        )
        ssm_read_size = self._mamba_ssm_size[1]
    else:
        # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages
        # are smaller than D's. self._conv_decomp has D-sized dimensions,
        # but we need P-sized offsets. Scale down by |tp_ratio|.
        abs_ratio = -tp_ratio
        xb_p = self._conv_decomp.x_bytes // abs_ratio
        bb_p = self._conv_decomp.b_bytes // abs_ratio
        # (offset, size) for x, B, C — B and C share the same size.
        conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)]
        ssm_read_size = nixl_agent_meta.ssm_sizes[1]
    remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id]
    # Remote physical blocks collapse into logical mamba blocks.
    num_blocks = nixl_agent_meta.num_blocks // remote_ratio
    device_id = nixl_agent_meta.device_id
    result: list[tuple[int, int, int]] = []
    # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case
    # block lengths vary across layers (e.g. MLA).
    for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
        page_stride = nixl_agent_meta.block_lens[i] * remote_ratio
        for off, sz in conv_offsets:
            for blk in range(num_blocks):
                result.append((base_addr + blk * page_stride + off, sz, device_id))
        # SSM temporal state is also TP-sharded on the heads dimension.
        for blk in range(num_blocks):
            ssm_addr = (
                base_addr
                + blk * page_stride
                + conv_size_remote
                + local_offset * ssm_read_size
            )
            result.append((ssm_addr, ssm_read_size, device_id))
    return result
def register_local_xfer_handler(
self,
block_size: int,
@@ -1823,13 +1996,22 @@ class NixlConnectorWorker:
self.device_id,
)
# NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used
# right now — we use _build_mamba_local instead for the 3-read
# approach. However, we might still need this as a fallback for homogeneous TP.
register_blocks(blocks_data, mamba=False)
if self._has_mamba:
assert self.num_descs == len(blocks_data)
logger.debug(
"Registering additional %s local Mamba blocks", len(blocks_data)
# TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is
# unnecessary — a single conv desc per block suffices. Consider
# adding a fast path that falls back to the standard 2-region
# registration (register_blocks mamba=True) when no hetero-TP
# remote has been seen. Currently we always register 4 regions
# because local descs are created before knowing the remote TP.
logger.debug("Registering local Mamba descriptors (4 regions/layer)")
blocks_data.extend(
self._build_mamba_local(local_base_addresses, block_size_ratio)
)
register_blocks(blocks_data, mamba=True)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
@@ -1880,6 +2062,9 @@ class NixlConnectorWorker:
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
so that the whole cache is shared by "tp_ratio" D TP workers.
For Mamba hetero-TP, both tp_ratio > 0 (D_TP > P_TP) and
tp_ratio < 0 (P_TP > D_TP) are supported by the 3-read transfer.
""" # noqa: E501
engine_id = nixl_agent_meta.engine_id
# TODO re-evaluate refreshing for scaling/recovery
@@ -1915,6 +2100,10 @@ class NixlConnectorWorker:
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
if self._has_mamba:
self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio(
nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0]
)
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
@@ -1931,6 +2120,21 @@ class NixlConnectorWorker:
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
# Create transfer config (single source of truth for descriptor sizes).
if self._has_mamba and engine_id not in self._transfer_configs:
self._transfer_configs[engine_id] = HeteroTPTransferConfig(
tp_ratio=tp_ratio,
K=kv_topo.total_num_kv_heads,
d_tp=self.world_size,
p_tp=remote_tp_size,
d_rank=self.tp_rank,
use_mla=self.use_mla,
d_block_len=self.block_len_per_layer[0],
p_block_len=nixl_agent_meta.block_lens[0],
is_blocks_first=kv_topo.is_kv_layout_blocks_first,
)
logger.info("Created %s", self._transfer_configs[engine_id].describe())
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
@@ -1947,21 +2151,48 @@ class NixlConnectorWorker:
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
abs_tp = -tp_ratio
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
if self._has_mamba:
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
if transfer_cfg.needs_split_handles:
# Mamba-HMA: FA and Mamba use different split factors.
for handle_data in transfer_cfg.compute_split_handle_data(
self.src_blocks_data, self.num_descs, abs_tp
):
descs = self.nixl_wrapper.get_xfer_descs(
handle_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs
)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
logger.info(
"Mamba-HMA split handles: targets=%s, fa_reads=%s, "
"fa_entry=%s, mamba_reads=%s, num_descs=%s",
transfer_cfg.transfer_targets,
transfer_cfg.physical_fa_num_reads,
transfer_cfg.fa_entry_size,
transfer_cfg.mamba_num_reads,
self.num_descs,
)
else:
# Original path: uniform divide by abs_tp (non-Mamba-HMA).
for i in range(abs_tp):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
remote_block_len = local_block_len // abs_tp
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
### Register remote agent memory regions
blocks_data = []
@@ -2044,13 +2275,33 @@ class NixlConnectorWorker:
self.tp_rank,
)
register_remote_blocks(blocks_data, mamba=False)
if self._has_mamba:
# Create extra descs for the Mamba "view" of the same KV cache tensors.
# Mamba-HMA: separate FA registration with GQA-aware sizing,
# plus mamba 3-read registration for the Mamba "view" of the
# same KV cache tensors.
logger.debug(
"Registering additional %s remote Mamba blocks", len(blocks_data)
"Registering remote Mamba blocks for engine %s rank %s",
engine_id,
remote_tp_rank,
)
register_remote_blocks(blocks_data, mamba=True)
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
blocks_data.extend(
self._build_fa_remote_for_mamba(
nixl_agent_meta,
transfer_cfg,
block_size_ratio,
kv_topo,
)
)
blocks_data.extend(
self._build_mamba_remote(
nixl_agent_meta,
tp_ratio,
)
)
else:
register_remote_blocks(blocks_data, mamba=False)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
@@ -2083,17 +2334,17 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
)
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
# num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba.
# Mamba models can have replicated FA KV with tp_ratio < 0.
if not self._has_mamba:
assert not (
tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)
)
if self._is_hma_required:
assert block_size_ratio == 1, (
"HMA does not support different remote block size yet"
)
# Mamba additional constraints
if self._has_mamba:
assert tp_ratio == 1, "Mamba does not support heterogeneous TP yet"
kv_cache_layout = (
self.kv_cache_layout
if not self.use_host_buffer
@@ -2138,11 +2389,14 @@ class NixlConnectorWorker:
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
# With replicated KV cache, only the number of blocks can differ.
for i in range(len(self.block_len_per_layer)):
assert (
self.block_len_per_layer[i] // block_size_ratio
== nixl_agent_meta.block_lens[i]
), "KV cache sizes must match between P and D when replicated"
# TODO (ZhanqiuHu): For mamba models, validate FA and mamba
# block_lens separately.
if not self._has_mamba:
for i in range(len(self.block_len_per_layer)):
assert (
self.block_len_per_layer[i] // block_size_ratio
== nixl_agent_meta.block_lens[i]
), "KV cache sizes must match between P and D when replicated"
else:
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
@@ -2150,25 +2404,31 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
)
if tp_ratio > 0:
# Remote tp is smaller: remote block_len size is bigger
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
# HMA hybrid models (mamba+attention) pad block_len to
# max(attn_page, mamba_page), so the linear tp_ratio scaling
# assumption only holds for pure-attention models.
if not self._has_mamba:
if tp_ratio > 0:
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N,"
" local_kv_heads*tp_ratio, page_size, head_dim] and "
"same dtype."
)
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported"
" when P TP > D TP."
)
assert remote_block_len == self.block_len_per_layer[0] // (
-tp_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N,"
" local_kv_heads/tp_ratio, page_size, head_dim] and "
"same dtype."
)
# TP workers that handshake with same remote have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
@@ -2471,9 +2731,8 @@ class NixlConnectorWorker:
meta.local_block_ids
)
assert meta.remote is not None
meta.remote.block_ids = self._logical_to_kernel_block_ids(
meta.remote.block_ids
)
# Remote block IDs are kept logical here; expanded in
# _read_blocks_for_req using the remote engine's phys ratio.
remote_engine_id = meta.remote.engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
@@ -2525,6 +2784,13 @@ class NixlConnectorWorker:
meta.remote.engine_id
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
if self._has_mamba:
# Expand remote logical → kernel block IDs.
meta.remote.block_ids = self._logical_to_remote_kernel_block_ids(
meta.remote.block_ids,
self._mamba_phys_ratio[meta.remote.engine_id],
)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
@@ -2558,12 +2824,26 @@ class NixlConnectorWorker:
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
remote_rank
]
local_ids: BlockIds = meta.local_physical_block_ids
remote_ids: BlockIds = meta.remote.block_ids
if self._has_mamba:
# Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets.
transfer_cfg = self._transfer_configs.get(meta.remote.engine_id)
assert transfer_cfg is not None
local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank(
remote_rank,
local_ids,
remote_ids,
self._is_mamba_group,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
local_block_ids=local_ids,
remote_block_ids=remote_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
@@ -2663,9 +2943,12 @@ class NixlConnectorWorker:
for i, remote_group in enumerate(remote_block_ids):
num_remote_blocks = len(remote_group)
num_local_blocks = len(local_block_ids[i])
assert num_local_blocks <= num_remote_blocks
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
if num_local_blocks < num_remote_blocks:
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
remote_block_ids[i] = remote_group[-num_local_blocks:]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
@@ -2781,16 +3064,22 @@ class NixlConnectorWorker:
# This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts).
ratio = self._physical_blocks_per_logical_kv_block
# SSM may register fewer num_blocks than FA
ratio = self._mamba_phys_ratio[engine_id]
logical_blocks = num_blocks // ratio
num_fa_descs = self.num_regions * num_blocks
# 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm).
mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None]
all_descs = []
for i, group in enumerate(block_ids):
stride = logical_blocks if self._is_mamba_group[i] else num_blocks
group_arr = np.asarray(group)[None, :]
offset = num_fa_descs if self._is_mamba_group[i] else 0
all_descs.append((region_ids * stride + group_arr + offset).flatten())
if self._is_mamba_group[i]:
all_descs.append(
(
mamba_region_ids * logical_blocks + group_arr + num_fa_descs
).flatten()
)
else:
all_descs.append((region_ids * num_blocks + group_arr).flatten())
return np.concatenate(all_descs)
def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
@@ -2818,6 +3107,36 @@ class NixlConnectorWorker:
for i, group in enumerate(block_ids)
]
def _logical_to_remote_kernel_block_ids(
    self, block_ids: BlockIds, remote_ratio: int
) -> BlockIds:
    """Translate logical block IDs into the remote engine's kernel block IDs.

    Args:
        block_ids: per-group lists of logical block IDs.
        remote_ratio: remote engine's physical blocks per logical block.

    Returns:
        The same per-group structure, with every non-Mamba (FA) group
        expanded so that a logical block L becomes kernel blocks
        [L*remote_ratio .. L*remote_ratio + local_ratio - 1]. Mamba
        groups are passed through untouched (1:1 logical-to-physical).
        NOTE(review): assumes remote_ratio >= local_ratio so the expanded
        IDs stay within the remote logical block — confirm with callers.
    """
    if remote_ratio == 1:
        # Remote stores one physical block per logical one: IDs already match.
        return block_ids
    local_ratio = self._physical_blocks_per_logical_kv_block
    group_specs = self.kv_cache_config.kv_cache_groups
    out: list[list[int]] = []
    for i, group in enumerate(block_ids):
        if isinstance(group_specs[i].kv_cache_spec, MambaSpec):
            # Mamba blocks need no expansion (1:1 logical-to-physical).
            out.append(group)
        else:
            out.append(
                [
                    logical * remote_ratio + sub
                    for logical in group
                    for sub in range(local_ratio)
                ]
            )
    return out
def get_backend_aware_kv_block_len(
self, layer_idx: int, first_split: bool = True, mamba_view: bool = False
) -> int:

View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mamba conv-state sub-projection decomposition for the 3-read transfer.
With DS conv state layout (dim, state_len), x/B/C sub-projections are
contiguous in memory. Each D rank reads its x, B, C slices via 3
separate RDMA transfers — no P-side permutation needed.
"""
import math
from dataclasses import dataclass
import torch
from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first
from vllm.v1.kv_cache_interface import MambaSpec
@dataclass(frozen=True)
class MambaConvSplitInfo:
    """Byte-level geometry of the x/B/C sub-projections of a Mamba conv state.

    Both the P and the D side consult this when registering NIXL
    descriptors. Every field is already local to this engine's TP shard
    (i.e. divided by the TP size).

    DS memory layout within one page (contiguous in memory):
    |--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -|
    """

    conv_rows: int  # conv_kernel - 1 (typically 3)
    x_local: int  # intermediate_size / TP (columns for x)
    b_local: int  # groups_ss / TP (columns for B; C is same size)
    conv_dtype_size: int  # bytes per element (e.g. 2 for float16)

    @property
    def conv_dim_local(self) -> int:
        """Per-rank total of conv columns: x plus B plus C."""
        return self.b_local * 2 + self.x_local

    @property
    def x_bytes(self) -> int:
        """Per-rank byte footprint of the x sub-projection."""
        return self.conv_dtype_size * self.conv_rows * self.x_local

    @property
    def b_bytes(self) -> int:
        """Per-rank byte footprint of the B sub-projection (C is identical)."""
        return self.conv_dtype_size * self.conv_rows * self.b_local

    @property
    def local_conv_offsets(self) -> list[tuple[int, int]]:
        """(byte_offset, byte_size) of x, B, C inside this engine's page.

        Consumed by both P and D when registering local descriptors.
        """
        spans: list[tuple[int, int]] = []
        cursor = 0
        for size in (self.x_bytes, self.b_bytes, self.b_bytes):
            spans.append((cursor, size))
            cursor += size
        return spans

    def remote_conv_offsets(
        self, local_rank_offset: int, tp_ratio: int
    ) -> list[tuple[int, int]]:
        """(byte_offset, byte_size) of this D rank's x, B, C slice inside
        one P page.

        Only the D side uses this, during remote descriptor registration.

        Args:
            local_rank_offset: which slice this D rank reads.
                tp_ratio > 0: tp_rank % tp_ratio (selects slice of P's page).
                tp_ratio < 0: always 0 (read P's full page).
            tp_ratio: effective ratio (>= 1 when D_TP > P_TP, 1 when
                P_TP > D_TP since each P rank is read in full).
        """
        x_size = self.x_bytes
        b_size = self.b_bytes
        x_section = tp_ratio * x_size  # whole remote x section, in bytes
        b_section = tp_ratio * b_size  # whole remote B section, in bytes
        return [
            (local_rank_offset * x_size, x_size),
            (x_section + local_rank_offset * b_size, b_size),
            (x_section + b_section + local_rank_offset * b_size, b_size),
        ]
def derive_mamba_conv_split(
    mamba_spec: MambaSpec,
    local_tp: int,
) -> MambaConvSplitInfo:
    """Compute per-rank x/B/C byte sizes from a MambaSpec.

    Runs once at init time on both P and D. Splits the conv dimension
    (= intermediate_size + 2 * groups_ss) into its x, B and C components.

    Args:
        mamba_spec: MambaSpec whose shapes are:
            shapes[0] = conv state: (conv_dim_local, conv_rows) in DS layout.
            shapes[1] = SSM temporal: (local_num_heads, head_dim).
        local_tp: this engine's tensor-parallel size.

    Returns:
        MambaConvSplitInfo carrying per-rank x_local, b_local, conv_rows
        and conv_dtype_size.

    Raises:
        NotImplementedError: for any mamba_type other than "mamba2".
    """
    if mamba_spec.mamba_type != "mamba2":
        raise NotImplementedError(
            f"3-read conv transfer only supports Mamba2 models, "
            f"got mamba_type={mamba_spec.mamba_type!r}. "
            f"Mamba1 SSM temporal shape is (intermediate_size // tp, state_size) "
            f"which cannot be used to reconstruct intermediate_size."
        )
    conv_shape = mamba_spec.shapes[0]
    assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}"
    # NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
    # in nixl_connector __init__. Use it directly instead of heuristic detection.
    assert is_conv_state_dim_first(), "3-read requires DS conv state layout"
    # DS layout means shapes[0] == (conv_dim_local, conv_rows).
    local_conv_dim, conv_rows = conv_shape
    # NOTE (ZhanqiuHu): intermediate_size (= global x dim) is not stored
    # in MambaSpec; it is rebuilt from the SSM temporal state shape,
    # shapes[1] = (local_num_heads, head_dim), already divided by TP.
    temporal_shape = mamba_spec.shapes[1]
    local_num_heads = temporal_shape[0]
    head_dim = temporal_shape[1]
    intermediate_size = head_dim * local_num_heads * local_tp
    # NOTE (ZhanqiuHu): global conv dim = intermediate_size + 2 * groups_ss,
    # where groups_ss is the B (= C) dimension. Since B and C always match
    # in size, groups_ss is half of whatever remains after removing x.
    remainder = local_conv_dim * local_tp - intermediate_size
    assert remainder > 0 and remainder % 2 == 0, (
        f"Conv dim ({local_conv_dim}*tp={local_tp}) doesn't decompose into "
        f"intermediate_size={intermediate_size} + 2*groups_ss. "
        f"remainder={remainder}"
    )
    groups_ss = remainder // 2
    conv_dtype_size = torch.tensor(
        [],
        dtype=mamba_spec.dtypes[0],  # type: ignore[misc]
    ).element_size()
    # Per-rank column counts are the global ones divided by TP.
    return MambaConvSplitInfo(
        conv_rows=conv_rows,
        x_local=intermediate_size // local_tp,
        b_local=groups_ss // local_tp,
        conv_dtype_size=conv_dtype_size,
    )
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int:
    """Derive _physical_blocks_per_logical_kv_block from remote metadata.

    The remote engine's ratio is not sent directly in the handshake, so we
    reconstruct it: total mamba state per logical block / block_len,
    rounded up to a whole number of physical blocks.

    Args:
        ssm_sizes: (conv_state_bytes, ssm_state_bytes) from NixlAgentMetadata.
        block_len: the engine's block_len in bytes (from block_lens[0]).

    Returns:
        Number of physical blocks needed to hold one logical block's state.

    Raises:
        ZeroDivisionError: if block_len is 0.
    """
    total_bytes = ssm_sizes[0] + ssm_sizes[1]
    # Integer ceiling division: exact for arbitrarily large byte counts,
    # unlike math.ceil on a float quotient (precision loss past 2**53).
    return -(-total_bytes // block_len)