[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)

zhanqiuhu
2026-04-06 13:07:02 -04:00
committed by GitHub
parent 93bada494f
commit bfdc0a3a99
5 changed files with 970 additions and 75 deletions


@@ -19,9 +19,9 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
sw_attn_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"


@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
+worker._mamba_phys_ratio = {engine_id: 1}
+worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
-# offset=num_descs=200
-# region0: [201, 202], region1: [301, 302]
-expected = [3, 5, 103, 105, 201, 202, 301, 302]
+# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
+# region0: [201, 202], region1: [301, 302],
+# region2: [401, 402], region3: [501, 502]
+expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
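
The expected IDs above follow a region-major layout: each region within a layer group owns a contiguous stride of descriptor slots, so a block's descriptor ID is offset + region * stride + block_id. Below is a minimal sketch reproducing this test's numbers; the helper name block_descs_ids is hypothetical, and the Mamba-side block IDs [1, 2] are taken as given by the test setup (the hunk does not show how they are derived):

def block_descs_ids(offset: int, stride: int, num_regions: int,
                    block_ids: list[int]) -> list[int]:
    # Region-major addressing: offset + region * stride + block_id.
    return [offset + region * stride + block_id
            for region in range(num_regions)
            for block_id in block_ids]

# FA group: 2 regions, stride = num_blocks = 100, offset = 0.
assert block_descs_ids(0, 100, 2, [3, 5]) == [3, 5, 103, 105]

# SSM group: 4 regions per Mamba layer (x, B, C, ssm),
# stride = logical_blocks = 100, offset = num_fa_descs = 200.
assert block_descs_ids(200, 100, 4, [1, 2]) == [
    201, 202, 301, 302, 401, 402, 501, 502]
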
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
+worker._mamba_phys_ratio = {engine_id: ratio}
+worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
fa_blocks = [3, 7] # kernel-level block IDs
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
-# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
-# region0: [801, 802], region1: [901, 902]
-expected = [3, 7, 403, 407, 801, 802, 901, 902]
+# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
+# 4 regions per Mamba layer (x, B, C, ssm)
+# region0: [801, 802], region1: [901, 902],
+# region2: [1001, 1002], region3: [1101, 1102]
+expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
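
The mismatch case only changes the stride arithmetic: the attention cache is addressed in kernel-level blocks that are ratio times finer than logical KV blocks, while the Mamba regions are still strided by logical blocks. A sketch of the arithmetic behind the expected list, again treating the logical Mamba block IDs [1, 2] as given by the test setup:

num_blocks = 400                      # kernel-level (physical) FA blocks
ratio = 4                             # physical blocks per logical KV block
logical_blocks = num_blocks // ratio  # 100: stride of each SSM region

fa = [r * num_blocks + b for r in range(2) for b in [3, 7]]
num_fa_descs = 2 * num_blocks         # 800: where the SSM range begins
ssm = [num_fa_descs + r * logical_blocks + b
       for r in range(4)              # x, B, C, ssm regions
       for b in [1, 2]]               # logical Mamba block IDs (assumed)

assert fa + ssm == [3, 7, 403, 407, 801, 802, 901, 902,
                    1001, 1002, 1101, 1102]
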
@@ -418,3 +425,29 @@ def test_has_mamba_init(
)
assert scheduler._has_mamba is expected_has_mamba
assert scheduler._is_hma_required is expected_is_hma
+@pytest.mark.cpu_test
+@pytest.mark.parametrize(
+    "ssm_sizes,block_len,expected_ratio",
+    [
+        # Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
+        ((36864, 2097152), 8192, 261),
+        # Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
+        ((18432, 1048576), 4096, 261),
+        # Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
+        ((9216, 524288), 4096, 131),
+    ],
+)
+def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
+    """Verify that compute_mamba_phys_ratio is TP-dependent.
+
+    With dimension-sharded Mamba state, the ratio differs across TP sizes
+    (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
+    _mamba_phys_ratio must be stored per-engine.
+    """
+    from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
+        compute_mamba_phys_ratio,
+    )
+
+    assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
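
The parametrized comments pin down the formula: the ratio is the ceiling of the total SSM state size over the attention block length, and TP sharding scales both the numerator and, at higher TP, the denominator. A sketch consistent with the three cases above; the real compute_mamba_phys_ratio lives in ssm_conv_transfer_utils and its exact signature may differ:

import math

def mamba_phys_ratio_sketch(ssm_sizes: tuple[int, ...], block_len: int) -> int:
    # How many attention-block-sized slots one Mamba state entry spans.
    return math.ceil(sum(ssm_sizes) / block_len)

# Nemotron 30B, per the cases above: TP=1 and TP=2 coincide because the
# state sizes and block_len halve together, while TP=4 halves the ratio.
assert mamba_phys_ratio_sketch((36864, 2097152), 8192) == 261  # TP=1
assert mamba_phys_ratio_sketch((18432, 1048576), 4096) == 261  # TP=2
assert mamba_phys_ratio_sketch((9216, 524288), 4096) == 131    # TP=4
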