[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)
This commit is contained in:
@@ -19,9 +19,9 @@ dp_ep_configs=(
|
||||
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
|
||||
)
|
||||
hybrid_ssm_configs=(
|
||||
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
|
||||
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
|
||||
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
|
||||
)
|
||||
sw_attn_configs=(
|
||||
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
|
||||
|
||||
@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
|
||||
worker._has_mamba = True
|
||||
worker._is_mamba_group = [False, True]
|
||||
worker._physical_blocks_per_logical_kv_block = 1
|
||||
worker._mamba_phys_ratio = {engine_id: 1}
|
||||
worker.block_len_per_layer = [100]
|
||||
# num_descs = num_regions * num_blocks (no blocks_first doubling)
|
||||
worker.num_descs = 2 * num_blocks
|
||||
|
||||
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
|
||||
# FA group: stride=num_blocks=100, offset=0
|
||||
# region0: [3, 5], region1: [103, 105]
|
||||
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
|
||||
# offset=num_descs=200
|
||||
# region0: [201, 202], region1: [301, 302]
|
||||
expected = [3, 5, 103, 105, 201, 202, 301, 302]
|
||||
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
|
||||
# region0: [201, 202], region1: [301, 302],
|
||||
# region2: [401, 402], region3: [501, 502]
|
||||
expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502]
|
||||
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
|
||||
|
||||
|
||||
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
|
||||
worker._has_mamba = True
|
||||
worker._is_mamba_group = [False, True]
|
||||
worker._physical_blocks_per_logical_kv_block = ratio
|
||||
worker._mamba_phys_ratio = {engine_id: ratio}
|
||||
worker.block_len_per_layer = [100]
|
||||
worker.num_descs = 2 * num_blocks # 800
|
||||
|
||||
fa_blocks = [3, 7] # kernel-level block IDs
|
||||
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
|
||||
|
||||
# FA group: stride=num_blocks=400, offset=0
|
||||
# region0: [3, 7], region1: [403, 407]
|
||||
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
|
||||
# region0: [801, 802], region1: [901, 902]
|
||||
expected = [3, 7, 403, 407, 801, 802, 901, 902]
|
||||
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
|
||||
# 4 regions per Mamba layer (x, B, C, ssm)
|
||||
# region0: [801, 802], region1: [901, 902],
|
||||
# region2: [1001, 1002], region3: [1101, 1102]
|
||||
expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102]
|
||||
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
|
||||
|
||||
|
||||
@@ -418,3 +425,29 @@ def test_has_mamba_init(
|
||||
)
|
||||
assert scheduler._has_mamba is expected_has_mamba
|
||||
assert scheduler._is_hma_required is expected_is_hma
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"ssm_sizes,block_len,expected_ratio",
|
||||
[
|
||||
# Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
|
||||
((36864, 2097152), 8192, 261),
|
||||
# Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
|
||||
((18432, 1048576), 4096, 261),
|
||||
# Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
|
||||
((9216, 524288), 4096, 131),
|
||||
],
|
||||
)
|
||||
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
|
||||
"""Verify that compute_mamba_phys_ratio is TP-dependent.
|
||||
|
||||
With dimension-sharded Mamba state, the ratio differs across TP sizes
|
||||
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
|
||||
_mamba_phys_ratio must be stored per-engine.
|
||||
"""
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
|
||||
compute_mamba_phys_ratio,
|
||||
)
|
||||
|
||||
assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
|
||||
|
||||
@@ -5,7 +5,7 @@ KV cache helper for store.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import torch
|
||||
@@ -516,6 +516,338 @@ class TpKVTopology:
|
||||
return cache if self.split_k_and_v else [cache]
|
||||
|
||||
|
||||
# ---- Mamba-HMA hetero-TP transfer config ----
|
||||
#
|
||||
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be
|
||||
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
|
||||
# conv/SSM state is almost always uniquely sharded per P rank. So the
|
||||
# number of P ranks D must read from can differ between FA and Mamba,
|
||||
# and they must be handled separately.
|
||||
|
||||
|
||||
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
|
||||
"""Physical KV head range stored in a rank's KV cache tensor.
|
||||
|
||||
When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank.
|
||||
When ``tp_size > num_heads``: 1 physical head per rank. Heads are
|
||||
distributed **contiguously** (matching vLLM's GQA weight partitioning):
|
||||
consecutive ranks share a head before moving to the next one.
|
||||
"""
|
||||
if tp_size <= num_heads:
|
||||
assert num_heads % tp_size == 0
|
||||
per_rank = num_heads // tp_size
|
||||
return range(rank * per_rank, (rank + 1) * per_rank)
|
||||
else:
|
||||
h = rank * num_heads // tp_size
|
||||
return range(h, h + 1)
|
||||
|
||||
|
||||
def _range_overlap(a: range, b: range) -> range:
|
||||
start = max(a.start, b.start)
|
||||
stop = min(a.stop, b.stop)
|
||||
return range(start, max(start, stop))
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeteroTPTransferConfig:
|
||||
"""Precomputed transfer plan for one (D rank, P engine) pair.
|
||||
|
||||
Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
|
||||
where FA and mamba require different splitting factors. Could be extended
|
||||
to other model types that need non-uniform hetero-TP transfer sizing.
|
||||
|
||||
All descriptor sizes are computed here. The guarantee is:
|
||||
local_entry_size == remote_entry_size (for NIXL)
|
||||
|
||||
Attributes that start with ``fa_`` concern FlashAttention KV cache.
|
||||
Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
|
||||
"""
|
||||
|
||||
# ---- Input parameters (from handshake) ----
|
||||
tp_ratio: int
|
||||
K: int # total_num_kv_heads (before TP sharding)
|
||||
d_tp: int # D engine's tensor_parallel_size
|
||||
p_tp: int # P engine's tensor_parallel_size
|
||||
d_rank: int # this D worker's TP rank
|
||||
use_mla: bool
|
||||
|
||||
# Per-layer block lengths (bytes, K+V combined for blocks_first).
|
||||
# Uniform across layers for current models.
|
||||
d_block_len: int # D's block_len_per_layer (representative)
|
||||
p_block_len: int # P's block_len_per_layer (from handshake)
|
||||
is_blocks_first: bool # kv_topo.is_kv_layout_blocks_first
|
||||
|
||||
# ---- Derived: computed in __post_init__ ----
|
||||
#
|
||||
# Physical heads per rank (what the KV tensor actually stores)
|
||||
d_physical_heads: int = field(init=False)
|
||||
p_physical_heads: int = field(init=False)
|
||||
|
||||
# How many distinct P ranks D needs for FA data
|
||||
physical_fa_num_reads: int = field(init=False)
|
||||
|
||||
# Which P ranks contribute unique FA heads (ordered by head index)
|
||||
fa_read_targets: list[int] = field(init=False)
|
||||
|
||||
# All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
|
||||
mamba_num_reads: int = field(init=False)
|
||||
|
||||
# All P ranks this D rank communicates with (FA ∪ mamba)
|
||||
transfer_targets: list[int] = field(init=False)
|
||||
|
||||
# FA descriptor entry size (K or V side, for blocks_first layout)
|
||||
# Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
|
||||
fa_entry_size: int = field(init=False)
|
||||
|
||||
# Replication flags
|
||||
is_d_replicated: bool = field(init=False)
|
||||
is_p_replicated: bool = field(init=False)
|
||||
|
||||
# Pre-built set for fast lookup
|
||||
_fa_target_set: frozenset[int] = field(init=False, repr=False)
|
||||
# Map: P rank → index in fa_read_targets (for head slot offset)
|
||||
_fa_target_index: dict[int, int] = field(init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
K = self.K
|
||||
self.is_d_replicated = self.d_tp > K
|
||||
self.is_p_replicated = self.p_tp > K
|
||||
|
||||
self.d_physical_heads = max(1, K // self.d_tp)
|
||||
self.p_physical_heads = max(1, K // self.p_tp)
|
||||
|
||||
abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1
|
||||
|
||||
# ---- Mamba range (computed first so FA can prefer ranks in it) ----
|
||||
mamba_range: range | None = None
|
||||
if self.tp_ratio < 0:
|
||||
mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp)
|
||||
|
||||
# ---- FA read targets ----
|
||||
if self.use_mla or self.tp_ratio >= 0:
|
||||
self.physical_fa_num_reads = 1
|
||||
self.fa_read_targets = (
|
||||
[0]
|
||||
if self.use_mla
|
||||
# Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
|
||||
else [
|
||||
self.d_rank // self.tp_ratio if self.tp_ratio > 0 else self.d_rank
|
||||
]
|
||||
)
|
||||
else:
|
||||
d_needs = _physical_head_range(self.d_tp, K, self.d_rank)
|
||||
# When mamba range exists, prefer P ranks within it so that
|
||||
# FA targets are a subset of mamba transfer_targets (avoids
|
||||
# orphaned FA targets outside the transfer loop).
|
||||
search_range = mamba_range if mamba_range is not None else range(self.p_tp)
|
||||
seen: set[tuple[int, int]] = set()
|
||||
targets: list[int] = []
|
||||
for p in search_range:
|
||||
p_has = _physical_head_range(self.p_tp, K, p)
|
||||
ov = _range_overlap(d_needs, p_has)
|
||||
if len(ov) > 0:
|
||||
key = (ov.start, ov.stop)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
targets.append(p)
|
||||
if not targets:
|
||||
# Fallback: search globally (should not happen in practice)
|
||||
for p in range(self.p_tp):
|
||||
p_has = _physical_head_range(self.p_tp, K, p)
|
||||
ov = _range_overlap(d_needs, p_has)
|
||||
if len(ov) > 0:
|
||||
key = (ov.start, ov.stop)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
targets.append(p)
|
||||
self.fa_read_targets = targets
|
||||
self.physical_fa_num_reads = len(targets)
|
||||
|
||||
self._fa_target_set = frozenset(self.fa_read_targets)
|
||||
self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)}
|
||||
|
||||
# ---- Mamba targets ----
|
||||
if mamba_range is not None and abs_tp > self.physical_fa_num_reads:
|
||||
self.mamba_num_reads = abs_tp
|
||||
self.transfer_targets = list(mamba_range)
|
||||
else:
|
||||
self.mamba_num_reads = self.physical_fa_num_reads
|
||||
self.transfer_targets = list(self.fa_read_targets)
|
||||
|
||||
# ---- FA entry size ----
|
||||
# For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
|
||||
# Use min(D, P) because D indexes into P when tp_ratio > 0,
|
||||
# and P is the natural unit when tp_ratio < 0.
|
||||
effective_block_len = min(self.d_block_len, self.p_block_len)
|
||||
if self.is_blocks_first:
|
||||
self.fa_entry_size = effective_block_len // 2
|
||||
else:
|
||||
self.fa_entry_size = effective_block_len
|
||||
|
||||
self._validate()
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Cross-check internal consistency."""
|
||||
if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0:
|
||||
logger.info(
|
||||
"Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
|
||||
"Using d_rank // tp_ratio routing with relative head offset.",
|
||||
self.d_tp,
|
||||
self.p_tp,
|
||||
self.K,
|
||||
)
|
||||
|
||||
# FA targets must be a subset of transfer_targets
|
||||
tt_set = set(self.transfer_targets)
|
||||
for t in self.fa_read_targets:
|
||||
if t not in tt_set:
|
||||
logger.error(
|
||||
"FA target P rank %d is NOT in transfer_targets %s. "
|
||||
"This will cause missed FA reads!",
|
||||
t,
|
||||
self.transfer_targets,
|
||||
)
|
||||
|
||||
# For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
|
||||
if (
|
||||
self.is_blocks_first
|
||||
and self.tp_ratio < 0
|
||||
and self.physical_fa_num_reads > 0
|
||||
):
|
||||
d_k_half = self.d_block_len // 2
|
||||
p_k_half = self.p_block_len // 2
|
||||
expected_local = d_k_half // self.physical_fa_num_reads
|
||||
if expected_local != p_k_half:
|
||||
logger.warning(
|
||||
"FA size mismatch: D_K_half=%d / reads=%d = %d, "
|
||||
"but P_K_half=%d. This may indicate a head count or "
|
||||
"Mamba-HMA inflation inconsistency.",
|
||||
d_k_half,
|
||||
self.physical_fa_num_reads,
|
||||
expected_local,
|
||||
p_k_half,
|
||||
)
|
||||
|
||||
# ---- Query methods ----
|
||||
|
||||
def should_skip_fa(self, p_rank: int) -> bool:
|
||||
"""Whether to skip FA groups for this P rank (mamba-only transfer)."""
|
||||
return p_rank not in self._fa_target_set
|
||||
|
||||
def fa_head_slot(self, p_rank: int) -> int:
|
||||
"""Index into D's FA block for this P rank's head data.
|
||||
|
||||
For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
|
||||
For P ranks NOT in fa_read_targets (replicated duplicates),
|
||||
returns the slot of the matching FA target with the same head.
|
||||
"""
|
||||
if p_rank in self._fa_target_index:
|
||||
return self._fa_target_index[p_rank]
|
||||
# Duplicate head: find which fa_target has the same physical head
|
||||
p_head = _physical_head_range(self.p_tp, self.K, p_rank)
|
||||
for target in self.fa_read_targets:
|
||||
t_head = _physical_head_range(self.p_tp, self.K, target)
|
||||
if _range_overlap(p_head, t_head):
|
||||
return self._fa_target_index[target]
|
||||
return 0 # fallback
|
||||
|
||||
def fa_rank_offset(self, remote_kv_block_len: int) -> int:
|
||||
"""Byte offset into P's FA block for this D rank.
|
||||
|
||||
When D is replicated (D_TP > K), multiple D ranks share a head.
|
||||
Computes offset *relative to the target P rank's first head*
|
||||
so it works regardless of how many heads P has.
|
||||
When neither side replicates, falls back to tp_rank % tp_ratio.
|
||||
Returns 0 when D does not index into P's block.
|
||||
"""
|
||||
if self.use_mla or self.tp_ratio <= 0:
|
||||
return 0
|
||||
if self.is_d_replicated:
|
||||
d_head = self.d_rank * self.K // self.d_tp
|
||||
p_rank = self.fa_read_targets[0]
|
||||
p_start = p_rank * self.K // self.p_tp
|
||||
return (d_head - p_start) * remote_kv_block_len
|
||||
return self.d_rank % self.tp_ratio * remote_kv_block_len
|
||||
|
||||
@property
|
||||
def needs_split_handles(self) -> bool:
|
||||
"""Whether per-P-rank split handles are needed.
|
||||
|
||||
True when FA and mamba have different read counts, requiring
|
||||
different splitting factors in the local handle.
|
||||
"""
|
||||
return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1
|
||||
|
||||
def compute_split_handle_data(
|
||||
self,
|
||||
src_blocks_data: list[tuple[int, int, int]],
|
||||
num_fa_descs: int,
|
||||
abs_tp: int,
|
||||
) -> list[list[tuple[int, int, int]]]:
|
||||
"""Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles.
|
||||
|
||||
FA descriptors (indices < num_fa_descs) are sliced by
|
||||
``physical_fa_num_reads``; mamba descriptors are sliced uniformly
|
||||
by ``abs_tp``.
|
||||
|
||||
Returns one list of triples per transfer target.
|
||||
"""
|
||||
all_handle_data: list[list[tuple[int, int, int]]] = []
|
||||
for p_idx, p_rank in enumerate(self.transfer_targets):
|
||||
handle_data: list[tuple[int, int, int]] = []
|
||||
skip_fa = self.should_skip_fa(p_rank)
|
||||
fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0
|
||||
|
||||
for j, (addr, local_len, tp) in enumerate(src_blocks_data):
|
||||
if j < num_fa_descs:
|
||||
assert self.physical_fa_num_reads >= 1
|
||||
fa_chunk = local_len // self.physical_fa_num_reads
|
||||
handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp))
|
||||
else:
|
||||
mamba_chunk = local_len // abs_tp
|
||||
handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp))
|
||||
all_handle_data.append(handle_data)
|
||||
return all_handle_data
|
||||
|
||||
def filter_block_ids_for_rank(
|
||||
self,
|
||||
remote_rank: int,
|
||||
local_ids: BlockIds,
|
||||
remote_ids: BlockIds,
|
||||
is_mamba_group: list[bool],
|
||||
) -> tuple[BlockIds, BlockIds]:
|
||||
"""Zero out FA groups for P ranks outside fa_read_targets.
|
||||
|
||||
Returns (filtered_local_ids, filtered_remote_ids). When the
|
||||
remote rank carries FA data for this D rank, returns the inputs
|
||||
unchanged.
|
||||
"""
|
||||
if not self.should_skip_fa(remote_rank):
|
||||
return local_ids, remote_ids
|
||||
num_groups = len(local_ids)
|
||||
filtered_local: list[list[int]] = [
|
||||
[] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups)
|
||||
]
|
||||
filtered_remote: list[list[int]] = [
|
||||
[] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups)
|
||||
]
|
||||
return filtered_local, filtered_remote
|
||||
|
||||
def describe(self) -> str:
|
||||
"""One-line summary for logging."""
|
||||
return (
|
||||
f"HeteroTPTransferConfig("
|
||||
f"tp_ratio={self.tp_ratio}, K={self.K}, "
|
||||
f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, "
|
||||
f"physical_fa_reads={self.physical_fa_num_reads}, "
|
||||
f"mamba_reads={self.mamba_num_reads}, "
|
||||
f"fa_targets={self.fa_read_targets}, "
|
||||
f"transfer_targets={self.transfer_targets}, "
|
||||
f"fa_entry_size={self.fa_entry_size}, "
|
||||
f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})"
|
||||
)
|
||||
|
||||
|
||||
def get_current_attn_backends(
|
||||
vllm_config: VllmConfig, layer_names: list[str] | None = None
|
||||
) -> list[type[AttentionBackend]]:
|
||||
@@ -559,3 +891,50 @@ def get_current_attn_backend(
|
||||
) -> type[AttentionBackend]:
|
||||
"""Get the first attention backend for the given layers."""
|
||||
return get_current_attn_backends(vllm_config, layer_names)[0]
|
||||
|
||||
|
||||
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig
|
||||
# into a single engine-agnostic TransferTopology class.
|
||||
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data.
|
||||
#
|
||||
# @dataclass
|
||||
# class EngineTransferInfo:
|
||||
# """Per-remote-engine transfer state, computed at handshake."""
|
||||
# p_tp: int
|
||||
# tp_ratio: int
|
||||
# p_block_len: int
|
||||
# block_size: int
|
||||
# # Mamba-specific (None for non-mamba models)
|
||||
# fa_read_targets: list[int] | None = None
|
||||
# transfer_targets: list[int] | None = None
|
||||
# physical_fa_num_reads: int | None = None
|
||||
# mamba_num_reads: int | None = None
|
||||
# fa_entry_size: int | None = None
|
||||
#
|
||||
# class TransferTopology:
|
||||
# """Single source of truth for TP topology + transfer sizing."""
|
||||
# # Shared (set once at init, replaces duplicate fields)
|
||||
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
|
||||
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
|
||||
# total_num_kv_heads: int # == HeteroTP.K
|
||||
# is_mla: bool # == HeteroTP.use_mla
|
||||
# is_mamba: bool
|
||||
# is_blocks_first: bool # == HeteroTP.is_blocks_first
|
||||
# d_block_len: int
|
||||
#
|
||||
# # Per-engine (populated via register_engine() at handshake)
|
||||
# _engines: dict[EngineId, EngineTransferInfo]
|
||||
#
|
||||
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
|
||||
#
|
||||
# # General (from TpKVTopology)
|
||||
# def tp_ratio(self, engine_id) -> int: ...
|
||||
# def target_remote_ranks(self, engine_id) -> list[int]: ...
|
||||
# def is_kv_replicated(self, engine_id) -> bool: ...
|
||||
#
|
||||
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
|
||||
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
|
||||
# def physical_fa_num_reads(self, engine_id) -> int: ...
|
||||
# def transfer_targets(self, engine_id) -> list[int]: ...
|
||||
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
|
||||
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...
|
||||
|
||||
@@ -25,6 +25,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
BlockIds,
|
||||
EngineId,
|
||||
HeteroTPTransferConfig,
|
||||
TpKVTopology,
|
||||
get_current_attn_backend,
|
||||
get_current_attn_backends,
|
||||
@@ -47,12 +48,18 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
|
||||
MambaConvSplitInfo,
|
||||
compute_mamba_phys_ratio,
|
||||
derive_mamba_conv_split,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||
@@ -1038,7 +1045,7 @@ class NixlConnectorWorker:
|
||||
}
|
||||
self.hma_group_size = len(kv_cache_config.kv_cache_tensors)
|
||||
|
||||
# Mamba metadata
|
||||
# ---- Mamba model state (derived from model config) ----
|
||||
self._is_mamba_group = [
|
||||
isinstance(group.kv_cache_spec, MambaSpec)
|
||||
for group in kv_cache_config.kv_cache_groups
|
||||
@@ -1065,6 +1072,17 @@ class NixlConnectorWorker:
|
||||
ssm_shape.numel() * ssm_nbytes,
|
||||
)
|
||||
self._mamba_ssm_size = mamba_ssm_size
|
||||
# Conv state sub-projection decomposition (None when no Mamba).
|
||||
# The 3-read transfer requires DS (dim, state_len) conv layout so
|
||||
# that x/B/C sub-projections are contiguous in memory.
|
||||
self._conv_decomp: MambaConvSplitInfo | None = None
|
||||
if self._has_mamba:
|
||||
assert is_conv_state_dim_first(), (
|
||||
"3-read Mamba conv transfer requires DS conv state layout. "
|
||||
"Set VLLM_SSM_CONV_STATE_LAYOUT=DS"
|
||||
)
|
||||
local_tp = vllm_config.parallel_config.tensor_parallel_size
|
||||
self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp)
|
||||
|
||||
# Agent.
|
||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||
@@ -1175,6 +1193,16 @@ class NixlConnectorWorker:
|
||||
self.dst_num_blocks: dict[EngineId, int] = {}
|
||||
self._registered_descs: list[Any] = []
|
||||
|
||||
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
|
||||
# Per-engine transfer config (source of truth for FA/mamba sizing).
|
||||
self._transfer_configs: dict[str, HeteroTPTransferConfig] = {}
|
||||
# NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine.
|
||||
# compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len)
|
||||
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
|
||||
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
|
||||
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
|
||||
self._mamba_phys_ratio: dict[EngineId, int] = {}
|
||||
|
||||
# In progress transfers.
|
||||
# [req_id -> list[handle]]
|
||||
self._recving_metadata: dict[ReqId, ReqMeta] = {}
|
||||
@@ -1701,8 +1729,7 @@ class NixlConnectorWorker:
|
||||
# then duplicate it logically to be able to index SSM/Conv separately.
|
||||
self.num_regions *= 2
|
||||
|
||||
# TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to
|
||||
# support heterogeneous TP.
|
||||
# Total local FA descriptors (boundary between FA and mamba descs).
|
||||
self.num_descs = self.num_regions * self.num_blocks
|
||||
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type)
|
||||
@@ -1715,6 +1742,9 @@ class NixlConnectorWorker:
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
|
||||
if self._has_mamba:
|
||||
self._mamba_phys_ratio[self.engine_id] = (
|
||||
self._physical_blocks_per_logical_kv_block
|
||||
)
|
||||
logger.info(
|
||||
"Hybrid SSM registration: num_blocks=%s, "
|
||||
"logical_num_blocks=%s, ratio=%s, num_regions=%s, "
|
||||
@@ -1755,6 +1785,149 @@ class NixlConnectorWorker:
|
||||
agent_metadata_bytes=encoder.encode(agent_metadata),
|
||||
)
|
||||
|
||||
def _build_mamba_local(
|
||||
self,
|
||||
base_addresses: list[int],
|
||||
block_size_ratio: int,
|
||||
) -> list[tuple[int, int, int]]:
|
||||
"""Build 4 desc regions (x, B, C, ssm) per layer for local mamba
|
||||
blocks, enabling the 3-read transfer with DS conv layout."""
|
||||
assert block_size_ratio == 1, (
|
||||
"Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
|
||||
f"Got block_size_ratio={block_size_ratio}."
|
||||
)
|
||||
assert self._conv_decomp is not None
|
||||
conv_offsets = self._conv_decomp.local_conv_offsets
|
||||
conv_size, ssm_size = self._mamba_ssm_size
|
||||
num_blocks = self._logical_num_blocks * block_size_ratio
|
||||
phys_ratio = self._physical_blocks_per_logical_kv_block
|
||||
|
||||
result: list[tuple[int, int, int]] = []
|
||||
for i, base_addr in enumerate(base_addresses):
|
||||
page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio
|
||||
for off, sz in conv_offsets:
|
||||
for blk in range(num_blocks):
|
||||
result.append(
|
||||
(base_addr + blk * page_stride + off, sz, self.device_id)
|
||||
)
|
||||
# SSM temporal state follows the conv state.
|
||||
for blk in range(num_blocks):
|
||||
result.append(
|
||||
(
|
||||
base_addr + blk * page_stride + conv_size,
|
||||
ssm_size,
|
||||
self.device_id,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def _build_fa_remote_for_mamba(
|
||||
self,
|
||||
nixl_agent_meta: NixlAgentMetadata,
|
||||
transfer_cfg: HeteroTPTransferConfig,
|
||||
block_size_ratio: int,
|
||||
kv_topo: TpKVTopology,
|
||||
) -> list[tuple[int, int, int]]:
|
||||
"""Build remote FA descriptors for mamba models.
|
||||
|
||||
Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset
|
||||
instead of the standard uniform tp_ratio split.
|
||||
"""
|
||||
assert block_size_ratio == 1, (
|
||||
"Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
|
||||
f"Got block_size_ratio={block_size_ratio}."
|
||||
)
|
||||
# TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
|
||||
# hetero-TP logic stabilizes.
|
||||
tp_ratio = transfer_cfg.tp_ratio
|
||||
result: list[tuple[int, int, int]] = []
|
||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||
local_block_len = self.get_backend_aware_kv_block_len(
|
||||
layer_idx=i, first_split=True, mamba_view=False
|
||||
)
|
||||
remote_kv_block_len = local_block_len // block_size_ratio
|
||||
if block_size_ratio > 1:
|
||||
local_block_len = remote_kv_block_len
|
||||
|
||||
if tp_ratio < 0 and not self.use_mla:
|
||||
local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads
|
||||
|
||||
rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len)
|
||||
|
||||
num_blocks = nixl_agent_meta.num_blocks
|
||||
page_size = nixl_agent_meta.block_lens[i]
|
||||
for block_id in range(num_blocks):
|
||||
block_offset = block_id * page_size
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
result.append((addr, local_block_len, nixl_agent_meta.device_id))
|
||||
|
||||
if kv_topo.is_kv_layout_blocks_first:
|
||||
second_split = self.get_backend_aware_kv_block_len(
|
||||
layer_idx=i, first_split=False, mamba_view=False
|
||||
)
|
||||
if tp_ratio < 0 and not self.use_mla:
|
||||
second_split = second_split // transfer_cfg.physical_fa_num_reads
|
||||
for block_id in range(num_blocks):
|
||||
block_offset = block_id * page_size
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||
result.append((v_addr, second_split, nixl_agent_meta.device_id))
|
||||
return result
|
||||
|
||||
def _build_mamba_remote(
|
||||
self,
|
||||
nixl_agent_meta: NixlAgentMetadata,
|
||||
tp_ratio: int,
|
||||
) -> list[tuple[int, int, int]]:
|
||||
"""Build 4 remote desc regions (x, B, C, ssm) per layer for
|
||||
the 3-read transfer. For hetero-TP, each D rank reads only its
|
||||
sub-projection slice from the P rank."""
|
||||
assert self._conv_decomp is not None
|
||||
effective_ratio = max(tp_ratio, 1)
|
||||
# Mamba conv state is always TP-sharded, even when attention KV
|
||||
# is replicated (num_kv_heads < tp_size).
|
||||
local_offset = self.tp_rank % effective_ratio
|
||||
conv_size_remote = nixl_agent_meta.ssm_sizes[0]
|
||||
|
||||
if tp_ratio >= 1:
|
||||
# D_TP >= P_TP: P page is larger, D reads its slice.
|
||||
conv_offsets = self._conv_decomp.remote_conv_offsets(
|
||||
local_offset, effective_ratio
|
||||
)
|
||||
ssm_read_size = self._mamba_ssm_size[1]
|
||||
else:
|
||||
# NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages
|
||||
# are smaller than D's. self._conv_decomp has D-sized dimensions,
|
||||
# but we need P-sized offsets. Scale down by |tp_ratio|.
|
||||
abs_ratio = -tp_ratio
|
||||
xb_p = self._conv_decomp.x_bytes // abs_ratio
|
||||
bb_p = self._conv_decomp.b_bytes // abs_ratio
|
||||
conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)]
|
||||
ssm_read_size = nixl_agent_meta.ssm_sizes[1]
|
||||
|
||||
remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id]
|
||||
num_blocks = nixl_agent_meta.num_blocks // remote_ratio
|
||||
device_id = nixl_agent_meta.device_id
|
||||
|
||||
result: list[tuple[int, int, int]] = []
|
||||
# NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case
|
||||
# block lengths vary across layers (e.g. MLA).
|
||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||
page_stride = nixl_agent_meta.block_lens[i] * remote_ratio
|
||||
for off, sz in conv_offsets:
|
||||
for blk in range(num_blocks):
|
||||
result.append((base_addr + blk * page_stride + off, sz, device_id))
|
||||
# SSM temporal state is also TP-sharded on the heads dimension.
|
||||
for blk in range(num_blocks):
|
||||
ssm_addr = (
|
||||
base_addr
|
||||
+ blk * page_stride
|
||||
+ conv_size_remote
|
||||
+ local_offset * ssm_read_size
|
||||
)
|
||||
result.append((ssm_addr, ssm_read_size, device_id))
|
||||
return result
|
||||
|
||||
def register_local_xfer_handler(
|
||||
self,
|
||||
block_size: int,
|
||||
@@ -1823,13 +1996,22 @@ class NixlConnectorWorker:
|
||||
self.device_id,
|
||||
)
|
||||
|
||||
# NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used
|
||||
# right now — we use _build_mamba_local instead for the 3-read
|
||||
# approach. However, we might still need this as a fallback for homogeneous TP.
|
||||
register_blocks(blocks_data, mamba=False)
|
||||
if self._has_mamba:
|
||||
assert self.num_descs == len(blocks_data)
|
||||
logger.debug(
|
||||
"Registering additional %s local Mamba blocks", len(blocks_data)
|
||||
# TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is
|
||||
# unnecessary — a single conv desc per block suffices. Consider
|
||||
# adding a fast path that falls back to the standard 2-region
|
||||
# registration (register_blocks mamba=True) when no hetero-TP
|
||||
# remote has been seen. Currently we always register 4 regions
|
||||
# because local descs are created before knowing the remote TP.
|
||||
logger.debug("Registering local Mamba descriptors (4 regions/layer)")
|
||||
blocks_data.extend(
|
||||
self._build_mamba_local(local_base_addresses, block_size_ratio)
|
||||
)
|
||||
register_blocks(blocks_data, mamba=True)
|
||||
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
||||
@@ -1880,6 +2062,9 @@ class NixlConnectorWorker:
|
||||
|
||||
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
|
||||
so that the whole cache is shared by "tp_ratio" D TP workers.
|
||||
|
||||
For Mamba hetero-TP, both tp_ratio > 0 (D_TP > P_TP) and
|
||||
tp_ratio < 0 (P_TP > D_TP) are supported by the 3-read transfer.
|
||||
""" # noqa: E501
|
||||
engine_id = nixl_agent_meta.engine_id
|
||||
# TODO re-evaluate refreshing for scaling/recovery
|
||||
@@ -1915,6 +2100,10 @@ class NixlConnectorWorker:
|
||||
|
||||
if engine_id not in self.dst_num_blocks:
|
||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||
if self._has_mamba:
|
||||
self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio(
|
||||
nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0]
|
||||
)
|
||||
|
||||
# Keep track of remote agent kv caches base addresses.
|
||||
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
|
||||
@@ -1931,6 +2120,21 @@ class NixlConnectorWorker:
|
||||
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
|
||||
)
|
||||
|
||||
# Create transfer config (single source of truth for descriptor sizes).
|
||||
if self._has_mamba and engine_id not in self._transfer_configs:
|
||||
self._transfer_configs[engine_id] = HeteroTPTransferConfig(
|
||||
tp_ratio=tp_ratio,
|
||||
K=kv_topo.total_num_kv_heads,
|
||||
d_tp=self.world_size,
|
||||
p_tp=remote_tp_size,
|
||||
d_rank=self.tp_rank,
|
||||
use_mla=self.use_mla,
|
||||
d_block_len=self.block_len_per_layer[0],
|
||||
p_block_len=nixl_agent_meta.block_lens[0],
|
||||
is_blocks_first=kv_topo.is_kv_layout_blocks_first,
|
||||
)
|
||||
logger.info("Created %s", self._transfer_configs[engine_id].describe())
|
||||
|
||||
logger.debug(
|
||||
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
|
||||
engine_id,
|
||||
@@ -1947,21 +2151,48 @@ class NixlConnectorWorker:
|
||||
# Remote tp_size > local tp_size: read from multiple remote ranks.
|
||||
# Logically "split" own regions into |tp_ratio| chunks. Mind that
|
||||
# we only do this once per remote tp_size (replica-friendly).
|
||||
abs_tp = -tp_ratio
|
||||
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
|
||||
for i in range(-tp_ratio):
|
||||
blocks_data = []
|
||||
for memory_region in self.src_blocks_data:
|
||||
addr, local_block_len, own_tp_rank = memory_region
|
||||
# Computing block len layer by layer allows for different
|
||||
# block sizes to be used.
|
||||
remote_block_len = local_block_len // (-tp_ratio)
|
||||
addr = addr + i * remote_block_len
|
||||
blocks_data.append((addr, remote_block_len, own_tp_rank))
|
||||
descs = self.nixl_wrapper.get_xfer_descs(
|
||||
blocks_data, self.nixl_memory_type
|
||||
)
|
||||
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
|
||||
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
|
||||
|
||||
if self._has_mamba:
|
||||
transfer_cfg = self._transfer_configs.get(engine_id)
|
||||
assert transfer_cfg is not None
|
||||
if transfer_cfg.needs_split_handles:
|
||||
# Mamba-HMA: FA and Mamba use different split factors.
|
||||
for handle_data in transfer_cfg.compute_split_handle_data(
|
||||
self.src_blocks_data, self.num_descs, abs_tp
|
||||
):
|
||||
descs = self.nixl_wrapper.get_xfer_descs(
|
||||
handle_data, self.nixl_memory_type
|
||||
)
|
||||
handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs
|
||||
)
|
||||
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
|
||||
|
||||
logger.info(
|
||||
"Mamba-HMA split handles: targets=%s, fa_reads=%s, "
|
||||
"fa_entry=%s, mamba_reads=%s, num_descs=%s",
|
||||
transfer_cfg.transfer_targets,
|
||||
transfer_cfg.physical_fa_num_reads,
|
||||
transfer_cfg.fa_entry_size,
|
||||
transfer_cfg.mamba_num_reads,
|
||||
self.num_descs,
|
||||
)
|
||||
else:
|
||||
# Original path: uniform divide by abs_tp (non-Mamba-HMA).
|
||||
for i in range(abs_tp):
|
||||
blocks_data = []
|
||||
for memory_region in self.src_blocks_data:
|
||||
addr, local_block_len, own_tp_rank = memory_region
|
||||
remote_block_len = local_block_len // abs_tp
|
||||
addr = addr + i * remote_block_len
|
||||
blocks_data.append((addr, remote_block_len, own_tp_rank))
|
||||
descs = self.nixl_wrapper.get_xfer_descs(
|
||||
blocks_data, self.nixl_memory_type
|
||||
)
|
||||
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
|
||||
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
|
||||
|
||||
### Register remote agent memory regions
|
||||
blocks_data = []
|
||||
@@ -2044,13 +2275,33 @@ class NixlConnectorWorker:
|
||||
self.tp_rank,
|
||||
)
|
||||
|
||||
register_remote_blocks(blocks_data, mamba=False)
|
||||
if self._has_mamba:
|
||||
# Create extra descs for the Mamba "view" of the same KV cache tensors.
|
||||
# Mamba-HMA: separate FA registration with GQA-aware sizing,
|
||||
# plus mamba 3-read registration for the Mamba "view" of the
|
||||
# same KV cache tensors.
|
||||
logger.debug(
|
||||
"Registering additional %s remote Mamba blocks", len(blocks_data)
|
||||
"Registering remote Mamba blocks for engine %s rank %s",
|
||||
engine_id,
|
||||
remote_tp_rank,
|
||||
)
|
||||
register_remote_blocks(blocks_data, mamba=True)
|
||||
transfer_cfg = self._transfer_configs.get(engine_id)
|
||||
assert transfer_cfg is not None
|
||||
blocks_data.extend(
|
||||
self._build_fa_remote_for_mamba(
|
||||
nixl_agent_meta,
|
||||
transfer_cfg,
|
||||
block_size_ratio,
|
||||
kv_topo,
|
||||
)
|
||||
)
|
||||
blocks_data.extend(
|
||||
self._build_mamba_remote(
|
||||
nixl_agent_meta,
|
||||
tp_ratio,
|
||||
)
|
||||
)
|
||||
else:
|
||||
register_remote_blocks(blocks_data, mamba=False)
|
||||
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||
@@ -2083,17 +2334,17 @@ class NixlConnectorWorker:
|
||||
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
|
||||
remote_engine_id
|
||||
)
|
||||
# Num kv_heads > tp_size and P TP > D TP case, not supported
|
||||
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
|
||||
# num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba.
|
||||
# Mamba models can have replicated FA KV with tp_ratio < 0.
|
||||
if not self._has_mamba:
|
||||
assert not (
|
||||
tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)
|
||||
)
|
||||
|
||||
if self._is_hma_required:
|
||||
assert block_size_ratio == 1, (
|
||||
"HMA does not support different remote block size yet"
|
||||
)
|
||||
# Mamba additional constraints
|
||||
if self._has_mamba:
|
||||
assert tp_ratio == 1, "Mamba does not support heterogeneous TP yet"
|
||||
|
||||
kv_cache_layout = (
|
||||
self.kv_cache_layout
|
||||
if not self.use_host_buffer
|
||||
@@ -2138,11 +2389,14 @@ class NixlConnectorWorker:
|
||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
|
||||
# With replicated KV cache, only the number of blocks can differ.
|
||||
for i in range(len(self.block_len_per_layer)):
|
||||
assert (
|
||||
self.block_len_per_layer[i] // block_size_ratio
|
||||
== nixl_agent_meta.block_lens[i]
|
||||
), "KV cache sizes must match between P and D when replicated"
|
||||
# TODO (ZhanqiuHu): For mamba models, validate FA and mamba
|
||||
# block_lens separately.
|
||||
if not self._has_mamba:
|
||||
for i in range(len(self.block_len_per_layer)):
|
||||
assert (
|
||||
self.block_len_per_layer[i] // block_size_ratio
|
||||
== nixl_agent_meta.block_lens[i]
|
||||
), "KV cache sizes must match between P and D when replicated"
|
||||
else:
|
||||
# When MLA is not used, this is a list of the same block length
|
||||
for block_len in nixl_agent_meta.block_lens:
|
||||
@@ -2150,25 +2404,31 @@ class NixlConnectorWorker:
|
||||
"All remote layers must have the same block size"
|
||||
)
|
||||
|
||||
if tp_ratio > 0:
|
||||
# Remote tp is smaller: remote block_len size is bigger
|
||||
assert (
|
||||
remote_block_len
|
||||
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
|
||||
), (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
|
||||
) # noqa: E501
|
||||
else:
|
||||
assert block_size_ratio == 1, (
|
||||
"Different local/remote block sizes are not supported when"
|
||||
" P TP > D TP."
|
||||
)
|
||||
# Remote tp is bigger: remote block_len size is smaller
|
||||
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
|
||||
) # noqa: E501
|
||||
# HMA hybrid models (mamba+attention) pad block_len to
|
||||
# max(attn_page, mamba_page), so the linear tp_ratio scaling
|
||||
# assumption only holds for pure-attention models.
|
||||
if not self._has_mamba:
|
||||
if tp_ratio > 0:
|
||||
assert (
|
||||
remote_block_len
|
||||
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
|
||||
), (
|
||||
"Remote P worker KV layer cache must be of shape [2, N,"
|
||||
" local_kv_heads*tp_ratio, page_size, head_dim] and "
|
||||
"same dtype."
|
||||
)
|
||||
else:
|
||||
assert block_size_ratio == 1, (
|
||||
"Different local/remote block sizes are not supported"
|
||||
" when P TP > D TP."
|
||||
)
|
||||
assert remote_block_len == self.block_len_per_layer[0] // (
|
||||
-tp_ratio
|
||||
), (
|
||||
"Remote P worker KV layer cache must be of shape [2, N,"
|
||||
" local_kv_heads/tp_ratio, page_size, head_dim] and "
|
||||
"same dtype."
|
||||
)
|
||||
|
||||
# TP workers that handhshake with same remote have same #blocks.
|
||||
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
|
||||
@@ -2471,9 +2731,8 @@ class NixlConnectorWorker:
|
||||
meta.local_block_ids
|
||||
)
|
||||
assert meta.remote is not None
|
||||
meta.remote.block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.remote.block_ids
|
||||
)
|
||||
# Remote block IDs are kept logical here; expanded in
|
||||
# _read_blocks_for_req using the remote engine's phys ratio.
|
||||
remote_engine_id = meta.remote.engine_id
|
||||
logger.debug(
|
||||
"start_load_kv for request %s from remote engine %s. "
|
||||
@@ -2525,6 +2784,13 @@ class NixlConnectorWorker:
|
||||
meta.remote.engine_id
|
||||
)
|
||||
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
|
||||
|
||||
if self._has_mamba:
|
||||
# Expand remote logical → kernel block IDs.
|
||||
meta.remote.block_ids = self._logical_to_remote_kernel_block_ids(
|
||||
meta.remote.block_ids,
|
||||
self._mamba_phys_ratio[meta.remote.engine_id],
|
||||
)
|
||||
# D may have to perform multiple reads from different remote ranks.
|
||||
for i, remote_rank in enumerate(remote_ranks):
|
||||
if self.use_mla and tp_ratio < 0 and i > 0:
|
||||
@@ -2558,12 +2824,26 @@ class NixlConnectorWorker:
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
|
||||
remote_rank
|
||||
]
|
||||
|
||||
local_ids: BlockIds = meta.local_physical_block_ids
|
||||
remote_ids: BlockIds = meta.remote.block_ids
|
||||
if self._has_mamba:
|
||||
# Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets.
|
||||
transfer_cfg = self._transfer_configs.get(meta.remote.engine_id)
|
||||
assert transfer_cfg is not None
|
||||
local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank(
|
||||
remote_rank,
|
||||
local_ids,
|
||||
remote_ids,
|
||||
self._is_mamba_group,
|
||||
)
|
||||
|
||||
self._read_blocks(
|
||||
request_id=req_id,
|
||||
dst_engine_id=meta.remote.engine_id,
|
||||
remote_request_id=meta.remote.request_id,
|
||||
local_block_ids=meta.local_physical_block_ids,
|
||||
remote_block_ids=meta.remote.block_ids,
|
||||
local_block_ids=local_ids,
|
||||
remote_block_ids=remote_ids,
|
||||
remote_rank=remote_rank,
|
||||
local_xfer_side_handle=local_xfer_side_handle,
|
||||
remote_xfer_side_handle=remote_xfer_side_handle,
|
||||
@@ -2663,9 +2943,12 @@ class NixlConnectorWorker:
|
||||
for i, remote_group in enumerate(remote_block_ids):
|
||||
num_remote_blocks = len(remote_group)
|
||||
num_local_blocks = len(local_block_ids[i])
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
if not self._is_mamba_group[i]:
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
# Skip mamba groups — their blocks represent full state (conv+ssm),
|
||||
# not per-token data, so trimming would corrupt the transfer.
|
||||
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
|
||||
remote_block_ids[i] = remote_group[-num_local_blocks:]
|
||||
|
||||
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
||||
@@ -2781,16 +3064,22 @@ class NixlConnectorWorker:
|
||||
# This is like having two "low-level views" of the same storage.
|
||||
# `num_fa_descs` offset must be computed per-engine since P and D can
|
||||
# have different num_blocks (and thus different FA descs counts).
|
||||
ratio = self._physical_blocks_per_logical_kv_block
|
||||
# SSM may register fewer num_blocks than FA
|
||||
ratio = self._mamba_phys_ratio[engine_id]
|
||||
logical_blocks = num_blocks // ratio
|
||||
num_fa_descs = self.num_regions * num_blocks
|
||||
# 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm).
|
||||
mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None]
|
||||
all_descs = []
|
||||
for i, group in enumerate(block_ids):
|
||||
stride = logical_blocks if self._is_mamba_group[i] else num_blocks
|
||||
group_arr = np.asarray(group)[None, :]
|
||||
offset = num_fa_descs if self._is_mamba_group[i] else 0
|
||||
all_descs.append((region_ids * stride + group_arr + offset).flatten())
|
||||
if self._is_mamba_group[i]:
|
||||
all_descs.append(
|
||||
(
|
||||
mamba_region_ids * logical_blocks + group_arr + num_fa_descs
|
||||
).flatten()
|
||||
)
|
||||
else:
|
||||
all_descs.append((region_ids * num_blocks + group_arr).flatten())
|
||||
return np.concatenate(all_descs)
|
||||
|
||||
def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
|
||||
@@ -2818,6 +3107,36 @@ class NixlConnectorWorker:
|
||||
for i, group in enumerate(block_ids)
|
||||
]
|
||||
|
||||
def _logical_to_remote_kernel_block_ids(
|
||||
self, block_ids: BlockIds, remote_ratio: int
|
||||
) -> BlockIds:
|
||||
"""Map logical block IDs to physical kernel block IDs on the remote.
|
||||
|
||||
Args:
|
||||
block_ids: per-group lists of logical block IDs.
|
||||
remote_ratio: remote engine's physical blocks per logical block.
|
||||
|
||||
Returns:
|
||||
Same structure with FA groups expanded (each logical block L
|
||||
becomes kernel blocks [L*remote_ratio .. L*remote_ratio +
|
||||
local_ratio - 1]). Mamba groups are passed through unchanged.
|
||||
"""
|
||||
local_ratio = self._physical_blocks_per_logical_kv_block
|
||||
if remote_ratio == 1:
|
||||
return block_ids
|
||||
local_arange = np.arange(local_ratio).reshape(1, -1)
|
||||
group_specs = self.kv_cache_config.kv_cache_groups
|
||||
result: list[list[int]] = []
|
||||
for i, group in enumerate(block_ids):
|
||||
if not isinstance(group_specs[i].kv_cache_spec, MambaSpec):
|
||||
arr = np.array(group).reshape(-1, 1)
|
||||
expanded = (arr * remote_ratio + local_arange).flatten()
|
||||
result.append(expanded.tolist())
|
||||
else:
|
||||
# Mamba blocks are 1:1 logical-to-physical (no expansion).
|
||||
result.append(group)
|
||||
return result
|
||||
|
||||
def get_backend_aware_kv_block_len(
|
||||
self, layer_idx: int, first_split: bool = True, mamba_view: bool = False
|
||||
) -> int:
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Mamba conv-state sub-projection decomposition for the 3-read transfer.
|
||||
|
||||
With DS conv state layout (dim, state_len), x/B/C sub-projections are
|
||||
contiguous in memory. Each D rank reads its x, B, C slices via 3
|
||||
separate RDMA transfers — no P-side permutation needed.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first
|
||||
from vllm.v1.kv_cache_interface import MambaSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MambaConvSplitInfo:
|
||||
"""Per-rank byte sizes of x, B, C sub-projections in the Mamba conv state.
|
||||
|
||||
Used by both P and D sides for NIXL descriptor registration.
|
||||
All fields are LOCAL to this engine's TP (already divided by TP size).
|
||||
|
||||
DS memory layout within one page (contiguous in memory):
|
||||
|--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -|
|
||||
"""
|
||||
|
||||
conv_rows: int # conv_kernel - 1 (typically 3)
|
||||
x_local: int # intermediate_size / TP (columns for x)
|
||||
b_local: int # groups_ss / TP (columns for B; C is same size)
|
||||
conv_dtype_size: int # bytes per element (e.g. 2 for float16)
|
||||
|
||||
@property
|
||||
def conv_dim_local(self) -> int:
|
||||
"""Total conv columns per rank: x + B + C."""
|
||||
return self.x_local + 2 * self.b_local
|
||||
|
||||
@property
|
||||
def x_bytes(self) -> int:
|
||||
"""Byte size of the x sub-projection for one rank."""
|
||||
return self.x_local * self.conv_rows * self.conv_dtype_size
|
||||
|
||||
@property
|
||||
def b_bytes(self) -> int:
|
||||
"""Byte size of the B (or C) sub-projection for one rank."""
|
||||
return self.b_local * self.conv_rows * self.conv_dtype_size
|
||||
|
||||
@property
|
||||
def local_conv_offsets(self) -> list[tuple[int, int]]:
|
||||
"""(byte_offset, byte_size) of x, B, C within this engine's page.
|
||||
|
||||
Used by both P and D for local descriptor registration.
|
||||
"""
|
||||
xb = self.x_bytes
|
||||
bb = self.b_bytes
|
||||
return [(0, xb), (xb, bb), (xb + bb, bb)]
|
||||
|
||||
def remote_conv_offsets(
|
||||
self, local_rank_offset: int, tp_ratio: int
|
||||
) -> list[tuple[int, int]]:
|
||||
"""(byte_offset, byte_size) of this D rank's x, B, C slice within
|
||||
one P page.
|
||||
|
||||
Used by D side only, during remote descriptor registration.
|
||||
|
||||
Args:
|
||||
local_rank_offset: which slice this D rank reads.
|
||||
tp_ratio > 0: tp_rank % tp_ratio (selects slice of P's page).
|
||||
tp_ratio < 0: always 0 (read P's full page).
|
||||
tp_ratio: effective ratio (>= 1 when D_TP > P_TP, 1 when
|
||||
P_TP > D_TP since each P rank is read in full).
|
||||
"""
|
||||
xb = self.x_bytes
|
||||
bb = self.b_bytes
|
||||
xr = xb * tp_ratio # full remote x section in bytes
|
||||
br = bb * tp_ratio # full remote B section in bytes
|
||||
return [
|
||||
(local_rank_offset * xb, xb),
|
||||
(xr + local_rank_offset * bb, bb),
|
||||
(xr + br + local_rank_offset * bb, bb),
|
||||
]
|
||||
|
||||
|
||||
def derive_mamba_conv_split(
|
||||
mamba_spec: MambaSpec,
|
||||
local_tp: int,
|
||||
) -> MambaConvSplitInfo:
|
||||
"""Derive per-rank x/B/C byte sizes from a MambaSpec.
|
||||
|
||||
Called once at init on both P and D. Decomposes the conv dimension
|
||||
(= intermediate_size + 2 * groups_ss) into its x, B, C parts.
|
||||
|
||||
Args:
|
||||
mamba_spec: MambaSpec whose shapes are:
|
||||
shapes[0] = conv state: (conv_dim_local, conv_rows) in DS layout.
|
||||
shapes[1] = SSM temporal: (local_num_heads, head_dim).
|
||||
local_tp: this engine's tensor-parallel size.
|
||||
|
||||
Returns:
|
||||
MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and
|
||||
conv_dtype_size.
|
||||
"""
|
||||
if mamba_spec.mamba_type != "mamba2":
|
||||
raise NotImplementedError(
|
||||
f"3-read conv transfer only supports Mamba2 models, "
|
||||
f"got mamba_type={mamba_spec.mamba_type!r}. "
|
||||
f"Mamba1 SSM temporal shape is (intermediate_size // tp, state_size) "
|
||||
f"which cannot be used to reconstruct intermediate_size."
|
||||
)
|
||||
|
||||
conv_shape = mamba_spec.shapes[0]
|
||||
assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}"
|
||||
|
||||
# NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
|
||||
# in nixl_connector __init__. Use it directly instead of heuristic detection.
|
||||
assert is_conv_state_dim_first(), "3-read requires DS conv state layout"
|
||||
local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows)
|
||||
conv_rows = conv_shape[1]
|
||||
|
||||
# NOTE (ZhanqiuHu): intermediate_size (= global x dim) is not stored
|
||||
# in MambaSpec, so we reconstruct it from the SSM temporal state shape:
|
||||
# shapes[1] = (local_num_heads, head_dim), already divided by TP.
|
||||
head_dim = mamba_spec.shapes[1][1]
|
||||
local_num_heads = mamba_spec.shapes[1][0]
|
||||
intermediate_size = local_num_heads * local_tp * head_dim
|
||||
|
||||
# NOTE (ZhanqiuHu): global conv dim = intermediate_size + 2 * groups_ss,
|
||||
# where groups_ss is the B (= C) dimension. B and C are always the same
|
||||
# size, so we recover groups_ss from the remainder after subtracting x.
|
||||
remainder = local_conv_dim * local_tp - intermediate_size
|
||||
assert remainder > 0 and remainder % 2 == 0, (
|
||||
f"Conv dim ({local_conv_dim}*tp={local_tp}) doesn't decompose into "
|
||||
f"intermediate_size={intermediate_size} + 2*groups_ss. "
|
||||
f"remainder={remainder}"
|
||||
)
|
||||
groups_ss = remainder // 2
|
||||
|
||||
conv_dtype_size = torch.tensor(
|
||||
[],
|
||||
dtype=mamba_spec.dtypes[0], # type: ignore[misc]
|
||||
).element_size()
|
||||
|
||||
# Divide by TP to get per-rank column counts.
|
||||
return MambaConvSplitInfo(
|
||||
conv_rows=conv_rows,
|
||||
x_local=intermediate_size // local_tp,
|
||||
b_local=groups_ss // local_tp,
|
||||
conv_dtype_size=conv_dtype_size,
|
||||
)
|
||||
|
||||
|
||||
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int:
|
||||
"""Derive _physical_blocks_per_logical_kv_block from remote metadata.
|
||||
|
||||
The remote engine's ratio is not sent directly in the handshake, so we
|
||||
reconstruct it: total mamba state per logical block / block_len.
|
||||
|
||||
Args:
|
||||
ssm_sizes: (conv_state_bytes, ssm_state_bytes) from NixlAgentMetadata.
|
||||
block_len: the engine's block_len in bytes (from block_lens[0]).
|
||||
"""
|
||||
return math.ceil((ssm_sizes[0] + ssm_sizes[1]) / block_len)
|
||||
Reference in New Issue
Block a user