[Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -1191,9 +1191,9 @@ def test_is_uniform_decode() -> None:
|
||||
current_platform.is_rocm(),
|
||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
||||
)
|
||||
def test_cudagraph_sizes_capped_for_mamba_cache():
|
||||
"""Test that cudagraph capture sizes are capped to num_blocks for
|
||||
hybrid models with Mamba layers.
|
||||
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
|
||||
"""Test that a ValueError is raised when max_num_seqs exceeds the
|
||||
available Mamba cache blocks for hybrid models with FULL cudagraphs.
|
||||
|
||||
See: https://github.com/vllm-project/vllm/issues/34094
|
||||
"""
|
||||
@@ -1284,23 +1284,8 @@ def test_cudagraph_sizes_capped_for_mamba_cache():
|
||||
)[0]
|
||||
num_blocks = kv_cache_config.num_blocks
|
||||
|
||||
# Set max_cudagraph_capture_size to a value larger than num_blocks
|
||||
# to trigger the Mamba capping logic.
|
||||
large_max = num_blocks + 100
|
||||
compilation_config = vllm_config.compilation_config
|
||||
compilation_config.max_cudagraph_capture_size = large_max
|
||||
compilation_config.cudagraph_capture_sizes = [
|
||||
s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max
|
||||
]
|
||||
# Force max_num_seqs to exceed num_blocks so the check triggers.
|
||||
runner.max_num_reqs = num_blocks + 100
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
# After initialization, cudagraph sizes should be capped
|
||||
assert compilation_config.max_cudagraph_capture_size <= num_blocks
|
||||
assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes)
|
||||
# Invariant: last element == max
|
||||
if compilation_config.cudagraph_capture_sizes:
|
||||
assert (
|
||||
compilation_config.cudagraph_capture_sizes[-1]
|
||||
== compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
with pytest.raises(ValueError, match="max_num_seqs"):
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
Reference in New Issue
Block a user