[Bugfix] Cap FULL decode cudagraph sizes for Mamba/hybrid models (#34094) (#34571)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
haosdent
2026-03-04 18:56:22 +08:00
committed by GitHub
parent a8f66cbde8
commit d6e04f4c43
4 changed files with 230 additions and 0 deletions

View File

@@ -570,3 +570,45 @@ def test_compile_sizes_padding_validation():
assert sorted(config.compile_sizes) == [3, 5, 7]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
@pytest.mark.parametrize(
"capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
[
# Normal capping: sizes filtered to <= num_blocks
(
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
512,
200,
[1, 2, 4, 8, 16, 32, 64, 128],
128,
),
# No capping needed: num_blocks >= max
([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
# Exact boundary: num_blocks == max (no capping)
([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
# All sizes capped: num_blocks < smallest size
([8, 16, 32], 32, 4, [], 0),
# num_blocks <= 0: early return, no change
([1, 2, 4], 4, 0, [1, 2, 4], 4),
],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
"""Test that cudagraph capture sizes are correctly capped to fit
available Mamba cache blocks.
See: https://github.com/vllm-project/vllm/issues/34094
"""
config = CompilationConfig(
cudagraph_capture_sizes=capture_sizes,
max_cudagraph_capture_size=max_size,
cudagraph_mode=CUDAGraphMode.NONE,
)
config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)
assert config.cudagraph_capture_sizes == expected_sizes
assert config.max_cudagraph_capture_size == expected_max
# Invariant: last element == max_cudagraph_capture_size
if expected_sizes:
assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size