[Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2026-03-30 11:41:50 +02:00
committed by GitHub
parent 03ac6ca895
commit cc06b4e86b
4 changed files with 41 additions and 131 deletions

View File

@@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation():
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
@pytest.mark.parametrize(
"capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
[
# Normal capping: sizes filtered to <= num_blocks
(
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
512,
200,
[1, 2, 4, 8, 16, 32, 64, 128],
128,
),
# No capping needed: num_blocks >= max
([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
# Exact boundary: num_blocks == max (no capping)
([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
# All sizes capped: num_blocks < smallest size
([8, 16, 32], 32, 4, [], 0),
# num_blocks <= 0: early return, no change
([1, 2, 4], 4, 0, [1, 2, 4], 4),
],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
"""Test that cudagraph capture sizes are correctly capped to fit
available Mamba cache blocks.
See: https://github.com/vllm-project/vllm/issues/34094
"""
config = CompilationConfig(
cudagraph_capture_sizes=capture_sizes,
max_cudagraph_capture_size=max_size,
cudagraph_mode=CUDAGraphMode.NONE,
)
config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)
assert config.cudagraph_capture_sizes == expected_sizes
assert config.max_cudagraph_capture_size == expected_max
# Invariant: last element == max_cudagraph_capture_size
if expected_sizes:
assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size
def test_inductor_asserts_default_disabled(monkeypatch):
"""Test that inductor runtime asserts are disabled by default
(INFO logging level) on torch < 2.12."""

View File

@@ -1191,9 +1191,9 @@ def test_is_uniform_decode() -> None:
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_cudagraph_sizes_capped_for_mamba_cache():
"""Test that cudagraph capture sizes are capped to num_blocks for
hybrid models with Mamba layers.
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
"""Test that a ValueError is raised when max_num_seqs exceeds the
available Mamba cache blocks for hybrid models with FULL cudagraphs.
See: https://github.com/vllm-project/vllm/issues/34094
"""
@@ -1284,23 +1284,8 @@ def test_cudagraph_sizes_capped_for_mamba_cache():
)[0]
num_blocks = kv_cache_config.num_blocks
# Set max_cudagraph_capture_size to a value larger than num_blocks
# to trigger the Mamba capping logic.
large_max = num_blocks + 100
compilation_config = vllm_config.compilation_config
compilation_config.max_cudagraph_capture_size = large_max
compilation_config.cudagraph_capture_sizes = [
s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max
]
# Force max_num_seqs to exceed num_blocks so the check triggers.
runner.max_num_reqs = num_blocks + 100
runner.initialize_kv_cache(kv_cache_config)
# After initialization, cudagraph sizes should be capped
assert compilation_config.max_cudagraph_capture_size <= num_blocks
assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes)
# Invariant: last element == max
if compilation_config.cudagraph_capture_sizes:
assert (
compilation_config.cudagraph_capture_sizes[-1]
== compilation_config.max_cudagraph_capture_size
)
with pytest.raises(ValueError, match="max_num_seqs"):
runner.initialize_kv_cache(kv_cache_config)