[Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation():
|
||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
    [
        # Normal capping: sizes filtered to <= num_blocks
        (
            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
            512,
            200,
            [1, 2, 4, 8, 16, 32, 64, 128],
            128,
        ),
        # No capping needed: num_blocks >= max
        ([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
        # Exact boundary: num_blocks == max (no capping)
        ([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
        # All sizes capped: num_blocks < smallest size
        ([8, 16, 32], 32, 4, [], 0),
        # num_blocks <= 0: early return, no change
        ([1, 2, 4], 4, 0, [1, 2, 4], 4),
    ],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
    capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
    """Test that cudagraph capture sizes are correctly capped to fit
    available Mamba cache blocks.

    Parametrized arguments:
        capture_sizes: initial ``cudagraph_capture_sizes`` given to the config.
        max_size: initial ``max_cudagraph_capture_size`` given to the config.
        num_blocks: value passed to
            ``adjust_cudagraph_sizes_for_mamba_cache``.
        expected_sizes: capture sizes expected after the adjustment.
        expected_max: ``max_cudagraph_capture_size`` expected after the
            adjustment.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    config = CompilationConfig(
        cudagraph_capture_sizes=capture_sizes,
        max_cudagraph_capture_size=max_size,
        cudagraph_mode=CUDAGraphMode.NONE,
    )

    config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)

    assert config.cudagraph_capture_sizes == expected_sizes
    assert config.max_cudagraph_capture_size == expected_max

    # Invariant: last element == max_cudagraph_capture_size
    # (only checkable when at least one capture size survives the capping).
    if expected_sizes:
        assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size
|
||||
|
||||
|
||||
def test_inductor_asserts_default_disabled(monkeypatch):
|
||||
"""Test that inductor runtime asserts are disabled by default
|
||||
(INFO logging level) on torch < 2.12."""
|
||||
|
||||
Reference in New Issue
Block a user