[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)
This commit is contained in:
@@ -19,9 +19,9 @@ dp_ep_configs=(
|
||||
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
|
||||
)
|
||||
hybrid_ssm_configs=(
|
||||
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
|
||||
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
|
||||
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
|
||||
)
|
||||
sw_attn_configs=(
|
||||
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
|
||||
|
||||
@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
|
||||
worker._has_mamba = True
|
||||
worker._is_mamba_group = [False, True]
|
||||
worker._physical_blocks_per_logical_kv_block = 1
|
||||
worker._mamba_phys_ratio = {engine_id: 1}
|
||||
worker.block_len_per_layer = [100]
|
||||
# num_descs = num_regions * num_blocks (no blocks_first doubling)
|
||||
worker.num_descs = 2 * num_blocks
|
||||
|
||||
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
|
||||
# FA group: stride=num_blocks=100, offset=0
|
||||
# region0: [3, 5], region1: [103, 105]
|
||||
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
|
||||
# offset=num_descs=200
|
||||
# region0: [201, 202], region1: [301, 302]
|
||||
expected = [3, 5, 103, 105, 201, 202, 301, 302]
|
||||
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
|
||||
# region0: [201, 202], region1: [301, 302],
|
||||
# region2: [401, 402], region3: [501, 502]
|
||||
expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502]
|
||||
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
|
||||
|
||||
|
||||
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
|
||||
worker._has_mamba = True
|
||||
worker._is_mamba_group = [False, True]
|
||||
worker._physical_blocks_per_logical_kv_block = ratio
|
||||
worker._mamba_phys_ratio = {engine_id: ratio}
|
||||
worker.block_len_per_layer = [100]
|
||||
worker.num_descs = 2 * num_blocks # 800
|
||||
|
||||
fa_blocks = [3, 7] # kernel-level block IDs
|
||||
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
|
||||
|
||||
# FA group: stride=num_blocks=400, offset=0
|
||||
# region0: [3, 7], region1: [403, 407]
|
||||
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
|
||||
# region0: [801, 802], region1: [901, 902]
|
||||
expected = [3, 7, 403, 407, 801, 802, 901, 902]
|
||||
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
|
||||
# 4 regions per Mamba layer (x, B, C, ssm)
|
||||
# region0: [801, 802], region1: [901, 902],
|
||||
# region2: [1001, 1002], region3: [1101, 1102]
|
||||
expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102]
|
||||
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
|
||||
|
||||
|
||||
@@ -418,3 +425,29 @@ def test_has_mamba_init(
|
||||
)
|
||||
assert scheduler._has_mamba is expected_has_mamba
|
||||
assert scheduler._is_hma_required is expected_is_hma
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "ssm_sizes,block_len,expected_ratio",
    [
        # Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
        ((36864, 2097152), 8192, 261),
        # Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
        ((18432, 1048576), 4096, 261),
        # Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
        ((9216, 524288), 4096, 131),
    ],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
    """Check that compute_mamba_phys_ratio yields a TP-dependent ratio.

    Each parametrized case supplies the SSM state sizes and the block
    length for one tensor-parallel configuration, together with the ratio
    expected for it. When the Mamba state is sharded along its dimensions,
    different TP sizes produce different ratios (for Nemotron 30B: 261 at
    TP=1 versus 131 at TP=4), which is the reason _mamba_phys_ratio has to
    be kept per engine.
    """
    # Imported inside the test, as in the original.
    from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
        compute_mamba_phys_ratio,
    )

    actual_ratio = compute_mamba_phys_ratio(ssm_sizes, block_len)
    assert actual_ratio == expected_ratio
|
||||
|
||||
Reference in New Issue
Block a user