[Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation():
     dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE)  # Should not raise
 
 
-@pytest.mark.parametrize(
-    "capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
-    [
-        # Normal capping: sizes filtered to <= num_blocks
-        (
-            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
-            512,
-            200,
-            [1, 2, 4, 8, 16, 32, 64, 128],
-            128,
-        ),
-        # No capping needed: num_blocks >= max
-        ([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
-        # Exact boundary: num_blocks == max (no capping)
-        ([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
-        # All sizes capped: num_blocks < smallest size
-        ([8, 16, 32], 32, 4, [], 0),
-        # num_blocks <= 0: early return, no change
-        ([1, 2, 4], 4, 0, [1, 2, 4], 4),
-    ],
-)
-def test_adjust_cudagraph_sizes_for_mamba_cache(
-    capture_sizes, max_size, num_blocks, expected_sizes, expected_max
-):
-    """Test that cudagraph capture sizes are correctly capped to fit
-    available Mamba cache blocks.
-
-    See: https://github.com/vllm-project/vllm/issues/34094
-    """
-    config = CompilationConfig(
-        cudagraph_capture_sizes=capture_sizes,
-        max_cudagraph_capture_size=max_size,
-        cudagraph_mode=CUDAGraphMode.NONE,
-    )
-    config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)
-    assert config.cudagraph_capture_sizes == expected_sizes
-    assert config.max_cudagraph_capture_size == expected_max
-    # Invariant: last element == max_cudagraph_capture_size
-    if expected_sizes:
-        assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size
-
-
 def test_inductor_asserts_default_disabled(monkeypatch):
     """Test that inductor runtime asserts are disabled by default
     (INFO logging level) on torch < 2.12."""
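For reference, here is a minimal standalone sketch (not part of this commit) of the capping rule that the removed parametrized test above exercised. The helper name and signature are invented for illustration; in vLLM the logic lived on CompilationConfig.adjust_cudagraph_sizes_for_mamba_cache, which is removed further down in this diff.

def cap_sizes_to_mamba_blocks(
    sizes: list[int], num_blocks: int
) -> tuple[list[int], int]:
    """Return (capped_sizes, new_max) under the old capping semantics."""
    if not sizes or num_blocks <= 0:
        # Early return: config left untouched (last parametrize case above).
        return sizes, (sizes[-1] if sizes else 0)
    capped = [s for s in sizes if s <= num_blocks]
    return capped, (capped[-1] if capped else 0)

# First and fourth parametrize cases above:
assert cap_sizes_to_mamba_blocks(
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 200
) == ([1, 2, 4, 8, 16, 32, 64, 128], 128)
assert cap_sizes_to_mamba_blocks([8, 16, 32], 4) == ([], 0)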
@@ -1191,9 +1191,9 @@ def test_is_uniform_decode() -> None:
     current_platform.is_rocm(),
     reason="Attention backend FLASHINFER is not supported on ROCm.",
 )
-def test_cudagraph_sizes_capped_for_mamba_cache():
-    """Test that cudagraph capture sizes are capped to num_blocks for
-    hybrid models with Mamba layers.
+def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
+    """Test that a ValueError is raised when max_num_seqs exceeds the
+    available Mamba cache blocks for hybrid models with FULL cudagraphs.
 
     See: https://github.com/vllm-project/vllm/issues/34094
     """
@@ -1284,23 +1284,8 @@ def test_cudagraph_sizes_capped_for_mamba_cache():
     )[0]
     num_blocks = kv_cache_config.num_blocks
 
-    # Set max_cudagraph_capture_size to a value larger than num_blocks
-    # to trigger the Mamba capping logic.
-    large_max = num_blocks + 100
-    compilation_config = vllm_config.compilation_config
-    compilation_config.max_cudagraph_capture_size = large_max
-    compilation_config.cudagraph_capture_sizes = [
-        s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max
-    ]
+    # Force max_num_seqs to exceed num_blocks so the check triggers.
+    runner.max_num_reqs = num_blocks + 100
 
-    runner.initialize_kv_cache(kv_cache_config)
-
-    # After initialization, cudagraph sizes should be capped
-    assert compilation_config.max_cudagraph_capture_size <= num_blocks
-    assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes)
-    # Invariant: last element == max
-    if compilation_config.cudagraph_capture_sizes:
-        assert (
-            compilation_config.cudagraph_capture_sizes[-1]
-            == compilation_config.max_cudagraph_capture_size
-        )
+    with pytest.raises(ValueError, match="max_num_seqs"):
+        runner.initialize_kv_cache(kv_cache_config)
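A hedged usage-level sketch of what the new check means for callers (the model id below is a placeholder, not a model referenced by this commit; max_num_seqs and gpu_memory_utilization are the two engine arguments named in the new error message). With the new behavior, a configuration whose max_num_seqs cannot be backed by Mamba cache blocks fails fast at initialization instead of silently shrinking the FULL-decode cudagraph sizes.

from vllm import LLM

llm = LLM(
    model="my-org/hybrid-mamba-model",  # placeholder; use a real hybrid Mamba checkpoint
    max_num_seqs=64,                    # keep at or below the available Mamba cache blocks
    gpu_memory_utilization=0.9,         # raising this leaves room for more cache blocks
)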
@@ -1279,58 +1279,6 @@ class CompilationConfig:
         self.max_cudagraph_capture_size = rounded_sizes[-1]
         self.cudagraph_capture_sizes = rounded_sizes
 
-    def adjust_cudagraph_sizes_for_mamba_cache(
-        self, num_mamba_cache_blocks: int
-    ) -> None:
-        """Cap cudagraph capture sizes to available Mamba cache blocks.
-
-        For hybrid Mamba/attention models, the Mamba conv_state and
-        ssm_state tensors have their first dimension equal to num_blocks
-        (from KVCacheConfig). During CUDA graph capture the decode batch
-        size equals num_tokens, so capture sizes exceeding num_blocks
-        would cause out-of-bounds access in Mamba kernels.
-
-        See: https://github.com/vllm-project/vllm/issues/34094
-        """
-        if not self.cudagraph_capture_sizes or num_mamba_cache_blocks <= 0:
-            return
-
-        assert self.max_cudagraph_capture_size is not None
-
-        if num_mamba_cache_blocks >= self.max_cudagraph_capture_size:
-            return
-
-        capped_sizes = [
-            s for s in self.cudagraph_capture_sizes if s <= num_mamba_cache_blocks
-        ]
-
-        if len(capped_sizes) == 0:
-            logger.warning(
-                "No valid cudagraph capture sizes remain after capping "
-                "to Mamba cache blocks (%d). The smallest capture size "
-                "was %d. Disabling cudagraph capture. Consider reducing "
-                "max_num_seqs or increasing available GPU memory.",
-                num_mamba_cache_blocks,
-                self.cudagraph_capture_sizes[0],
-            )
-            self.cudagraph_capture_sizes = []
-            self.max_cudagraph_capture_size = 0
-            return
-
-        logger.warning(
-            "Capping cudagraph capture sizes from max %d to %d to fit "
-            "Mamba cache blocks (%d blocks available). This limits the "
-            "maximum batch size that can use CUDA graphs. To increase "
-            "this limit, reduce max_num_seqs or increase available GPU "
-            "memory.",
-            self.max_cudagraph_capture_size,
-            capped_sizes[-1],
-            num_mamba_cache_blocks,
-        )
-
-        self.max_cudagraph_capture_size = capped_sizes[-1]
-        self.cudagraph_capture_sizes = capped_sizes
-
     def get_compile_ranges(self) -> list[Range]:
         """Get the compile ranges for the compilation config."""
         if self.compile_ranges_endpoints is None:
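A minimal illustration (assumed shapes, not vLLM's actual tensors) of the out-of-bounds hazard the removed docstring describes: the Mamba state buffers have a leading dimension of num_blocks, and a FULL decode cudagraph captured at batch size B indexes one state slot per sequence, so B > num_blocks would index past the end of the buffer.

import torch

num_blocks = 200                              # leading dim of conv_state / ssm_state
conv_state = torch.zeros(num_blocks, 4, 16)   # trailing dims are made up

capture_size = 512                            # decode batch size baked into the graph
slot_ids = torch.arange(capture_size)         # one cache slot per decode sequence
assert int(slot_ids.max()) >= num_blocks      # slots 200..511 fall outside conv_state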
@@ -5800,7 +5800,7 @@ class GPUModelRunner(
         )
         self.cache_config.num_gpu_blocks_override = saved_override
 
-        self.initialize_kv_cache(minimal_config)
+        self.initialize_kv_cache(minimal_config, is_profiling=True)
         self.cache_config.num_gpu_blocks = minimal_config.num_blocks
 
         logger.debug("Initialized minimal KV cache for CUDA graph profiling")
@@ -6121,7 +6121,11 @@ class GPUModelRunner(
         torch.accelerator.synchronize()
         self.maybe_remove_all_loras(self.lora_config)
 
-    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_attn_backend(
+        self,
+        kv_cache_config: KVCacheConfig,
+        is_profiling: bool = False,
+    ) -> None:
         """
         Initialize the attention backends and attention metadata builders.
         """
@@ -6193,7 +6197,9 @@ class GPUModelRunner(
 
         # Resolve cudagraph_mode before actually initialize metadata_builders
         self._check_and_update_cudagraph_mode(
-            attention_backend_list, kv_cache_config.kv_cache_groups
+            attention_backend_list,
+            kv_cache_config.kv_cache_groups,
+            is_profiling=is_profiling,
         )
 
         # Check if attention backend supports PCP&DCP and related features.
@@ -6237,6 +6243,7 @@ class GPUModelRunner(
         self,
         attention_backends: list[set[type[AttentionBackend]]],
         kv_cache_groups: list[KVCacheGroupSpec],
+        is_profiling: bool = False,
     ) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
@@ -6377,21 +6384,29 @@ class GPUModelRunner(
             self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
         )
 
-        # If the model has Mamba layers and cudagraph mode includes FULL
-        # decode, cap cudagraph capture sizes to the number of available
-        # Mamba cache blocks. Each decode request needs one conv_state
-        # cache line, so capture batch sizes cannot exceed num_blocks.
-        # Only FULL decode graphs are affected because PIECEWISE captures
-        # run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update).
+        # For Mamba models with FULL decode cudagraphs, each decode
+        # sequence needs one Mamba cache block. The decode cudagraph
+        # dispatcher already caps batch sizes at max_num_seqs, so we just
+        # need to verify that enough blocks exist. Raising here instead
+        # of silently capping cudagraph_capture_sizes avoids unintended
+        # restrictions on PIECEWISE (prefill) cudagraphs.
         # See: https://github.com/vllm-project/vllm/issues/34094
-        if cudagraph_mode.has_full_cudagraphs():
+        if cudagraph_mode.has_full_cudagraphs() and not is_profiling:
            has_mamba = any(
                isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups
            )
            if has_mamba and self.kv_cache_config is not None:
-                self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache(
-                    self.kv_cache_config.num_blocks
-                )
+                num_blocks = self.kv_cache_config.num_blocks
+                if self.max_num_reqs > num_blocks:
+                    raise ValueError(
+                        f"max_num_seqs ({self.max_num_reqs}) exceeds "
+                        f"available Mamba cache blocks ({num_blocks}). "
+                        f"Each decode sequence requires one Mamba cache "
+                        f"block, so CUDA graph capture cannot proceed. "
+                        f"Please lower max_num_seqs to at most "
+                        f"{num_blocks} or increase "
+                        f"gpu_memory_utilization."
+                    )
 
         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
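Condensed restatement (simplified names, not a verbatim excerpt) of the new validation above: with FULL decode cudagraphs and Mamba layers, the only requirement is max_num_seqs <= num_mamba_cache_blocks, since the cudagraph dispatcher never dispatches a decode batch larger than max_num_seqs.

def check_mamba_cache_blocks(max_num_seqs: int, num_blocks: int) -> None:
    if max_num_seqs > num_blocks:
        raise ValueError(
            f"max_num_seqs ({max_num_seqs}) exceeds available "
            f"Mamba cache blocks ({num_blocks})."
        )

check_mamba_cache_blocks(max_num_seqs=64, num_blocks=200)     # ok
# check_mamba_cache_blocks(max_num_seqs=300, num_blocks=200)  # would raise ValueError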
@@ -6752,7 +6767,11 @@ class GPUModelRunner(
         else:
             break
 
-    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_kv_cache(
+        self,
+        kv_cache_config: KVCacheConfig,
+        is_profiling: bool = False,
+    ) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
@@ -6764,7 +6783,7 @@ class GPUModelRunner(
         self._mamba_copy_bufs = None
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
-        self.initialize_attn_backend(kv_cache_config)
+        self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling)
         # The kernel block size for all KV cache groups. For example, if
         # kv_cache_manager uses block_size 256 for a given group, but the attention
         # backends for that group only supports block_size 64, we will return
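Sketch of how the new is_profiling flag threads through the runner (simplified signatures, not a verbatim excerpt of GPUModelRunner): the minimal KV cache built only for CUDA graph memory profiling passes is_profiling=True down the chain, so the Mamba cache-block check is skipped for that throwaway config and enforced only for the real KV cache.

class RunnerSketch:
    def initialize_kv_cache(self, kv_cache_config, is_profiling: bool = False):
        self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling)

    def initialize_attn_backend(self, kv_cache_config, is_profiling: bool = False):
        self._check_and_update_cudagraph_mode(kv_cache_config, is_profiling=is_profiling)

    def _check_and_update_cudagraph_mode(self, kv_cache_config, is_profiling: bool = False):
        if not is_profiling:
            pass  # the max_num_seqs vs. Mamba cache blocks check would run here

RunnerSketch().initialize_kv_cache(kv_cache_config=None, is_profiling=True)  # check skipped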