[Bugfix] Cap FULL decode cudagraph sizes for Mamba/hybrid models (#34094) (#34571)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
haosdent
2026-03-04 18:56:22 +08:00
committed by GitHub
parent a8f66cbde8
commit d6e04f4c43
4 changed files with 230 additions and 0 deletions

View File

@@ -570,3 +570,45 @@ def test_compile_sizes_padding_validation():
assert sorted(config.compile_sizes) == [3, 5, 7]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
@pytest.mark.parametrize(
    "capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
    [
        # Partial cap: everything above num_blocks is dropped.
        (
            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
            512,
            200,
            [1, 2, 4, 8, 16, 32, 64, 128],
            128,
        ),
        # num_blocks comfortably above the max: nothing changes.
        (
            [1, 2, 4, 8, 16],
            16,
            1000,
            [1, 2, 4, 8, 16],
            16,
        ),
        # num_blocks exactly equal to the max: still nothing changes.
        (
            [1, 2, 4, 8, 16, 32],
            32,
            32,
            [1, 2, 4, 8, 16, 32],
            32,
        ),
        # num_blocks below the smallest size: every size is removed.
        (
            [8, 16, 32],
            32,
            4,
            [],
            0,
        ),
        # Non-positive num_blocks: the method returns without touching
        # the config.
        (
            [1, 2, 4],
            4,
            0,
            [1, 2, 4],
            4,
        ),
    ],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
    capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
    """Check capping of cudagraph capture sizes to the Mamba block count.

    Each case builds a CompilationConfig with the given capture sizes and
    max, calls adjust_cudagraph_sizes_for_mamba_cache(num_blocks), and
    compares the resulting sizes and max against the expected values.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    cfg = CompilationConfig(
        cudagraph_capture_sizes=capture_sizes,
        max_cudagraph_capture_size=max_size,
        cudagraph_mode=CUDAGraphMode.NONE,
    )

    cfg.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)

    assert cfg.cudagraph_capture_sizes == expected_sizes
    assert cfg.max_cudagraph_capture_size == expected_max

    # The "last element == max" invariant only makes sense when at least
    # one size is expected to survive.
    if not expected_sizes:
        return
    assert cfg.cudagraph_capture_sizes[-1] == cfg.max_cudagraph_capture_size

View File

@@ -1199,3 +1199,123 @@ def test_is_uniform_decode() -> None:
num_reqs=15,
force_uniform_decode=False,
)
@pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_cudagraph_sizes_capped_for_mamba_cache():
"""Test that cudagraph capture sizes are capped to num_blocks for
hybrid models with Mamba layers.

The test builds a GPUModelRunner for a model with both Attention and
MambaMixer2 layers, sets the capture-size limit above the computed
num_blocks, and checks that initialize_kv_cache() brings the sizes
back to at most num_blocks.

See: https://github.com/vllm-project/vllm/issues/34094
"""
set_random_seed(42)
# Environment variables for a one-process setup: rank 0, world size 1,
# master on localhost.
update_environment_variables(
{
"RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
from tests.utils import ensure_current_vllm_config
with ensure_current_vllm_config():
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=1)
torch.set_default_dtype(torch.float16)
# Configs for the granite-4.0-tiny-preview checkpoint in float16, limited
# to 10 sequences and 512 batched tokens / model length.
model_config = ModelConfig(
model="ibm-granite/granite-4.0-tiny-preview",
dtype="float16",
)
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
is_encoder_decoder=model_config.is_encoder_decoder,
)
cache_config = CacheConfig(
block_size=BLOCK_SIZE,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
)
parallel_config = ParallelConfig()
# FLASHINFER backend; this is why the test is skipped on ROCm (see the
# skipif reason above).
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
attention_config=attention_config,
)
with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
# Construct two Attention layers (layers 0-1) and four MambaMixer2 layers
# (layers 2-5); each is built with its dict key as the `prefix` argument.
fwd_context = {}
for key in ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]:
fwd_context[key] = Attention(
num_heads=model_config.get_num_attention_heads(parallel_config),
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
scale=1.0,
prefix=key,
)
for key in [
"model.layers.2.mixer",
"model.layers.3.mixer",
"model.layers.4.mixer",
"model.layers.5.mixer",
]:
# Mixer dimensions are taken from the HF config of the checkpoint.
fwd_context[key] = MambaMixer2(
hidden_size=hf_config.hidden_size,
ssm_state_size=hf_config.mamba_d_state,
conv_kernel_size=hf_config.mamba_d_conv,
intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
use_conv_bias=hf_config.mamba_conv_bias,
use_bias=hf_config.mamba_proj_bias,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
cache_config=cache_config,
model_config=model_config,
prefix=key,
)
# fwd_context is not read again below; this assert is its only other use.
assert fwd_context is not None
runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
# 5 GiB is the memory figure handed to get_kv_cache_configs for the single
# spec; num_blocks below is whatever that call computes from it.
available_memory = 5 * GiB_bytes
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_spec], [available_memory]
)[0]
num_blocks = kv_cache_config.num_blocks
# Set max_cudagraph_capture_size to a value larger than num_blocks
# to trigger the Mamba capping logic.
large_max = num_blocks + 100
compilation_config = vllm_config.compilation_config
compilation_config.max_cudagraph_capture_size = large_max
# Powers of two up to 512, keeping only those that fit under large_max.
compilation_config.cudagraph_capture_sizes = [
s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max
]
runner.initialize_kv_cache(kv_cache_config)
# After initialization, cudagraph sizes should be capped
assert compilation_config.max_cudagraph_capture_size <= num_blocks
assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes)
# Invariant: last element == max
if compilation_config.cudagraph_capture_sizes:
assert (
compilation_config.cudagraph_capture_sizes[-1]
== compilation_config.max_cudagraph_capture_size
)

View File

@@ -1190,6 +1190,58 @@ class CompilationConfig:
self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes
def adjust_cudagraph_sizes_for_mamba_cache(
    self, num_mamba_cache_blocks: int
) -> None:
    """Cap cudagraph capture sizes to available Mamba cache blocks.

    For hybrid Mamba/attention models, the Mamba conv_state and
    ssm_state tensors have their first dimension equal to num_blocks
    (from KVCacheConfig). During CUDA graph capture the decode batch
    size equals num_tokens, so capture sizes exceeding num_blocks
    would cause out-of-bounds access in Mamba kernels.

    Capture sizes larger than ``num_mamba_cache_blocks`` are dropped and
    ``max_cudagraph_capture_size`` is lowered to the largest size that
    remains. The remaining sizes are stored in ascending order, so the
    last element always equals the new max even if the incoming list
    was not sorted. If no size remains, ``cudagraph_capture_sizes``
    becomes ``[]`` and ``max_cudagraph_capture_size`` becomes ``0``.

    The call is a no-op when there are no capture sizes, when
    ``num_mamba_cache_blocks`` is not positive, or when
    ``num_mamba_cache_blocks`` is already at least
    ``max_cudagraph_capture_size``.

    Args:
        num_mamba_cache_blocks: Upper bound to apply to the capture sizes.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    if not self.cudagraph_capture_sizes or num_mamba_cache_blocks <= 0:
        return

    assert self.max_cudagraph_capture_size is not None
    if num_mamba_cache_blocks >= self.max_cudagraph_capture_size:
        return

    # Sort the survivors so that [-1] is the maximum regardless of the
    # order in which the sizes were supplied.
    capped_sizes = sorted(
        s for s in self.cudagraph_capture_sizes if s <= num_mamba_cache_blocks
    )

    if not capped_sizes:
        logger.warning(
            "No valid cudagraph capture sizes remain after capping "
            "to Mamba cache blocks (%d). The smallest capture size "
            "was %d. Disabling cudagraph capture. Consider reducing "
            "max_num_seqs or increasing available GPU memory.",
            num_mamba_cache_blocks,
            # min() rather than [0]: do not rely on the list being sorted.
            min(self.cudagraph_capture_sizes),
        )
        self.cudagraph_capture_sizes = []
        self.max_cudagraph_capture_size = 0
        return

    logger.warning(
        "Capping cudagraph capture sizes from max %d to %d to fit "
        "Mamba cache blocks (%d blocks available). This limits the "
        "maximum batch size that can use CUDA graphs. To increase "
        "this limit, reduce max_num_seqs or increase available GPU "
        "memory.",
        self.max_cudagraph_capture_size,
        capped_sizes[-1],
        num_mamba_cache_blocks,
    )
    self.max_cudagraph_capture_size = capped_sizes[-1]
    self.cudagraph_capture_sizes = capped_sizes
def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:

View File

@@ -5700,6 +5700,22 @@ class GPUModelRunner(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
)
# If the model has Mamba layers and cudagraph mode includes FULL
# decode, cap cudagraph capture sizes to the number of available
# Mamba cache blocks. Each decode request needs one conv_state
# cache line, so capture batch sizes cannot exceed num_blocks.
# Only FULL decode graphs are affected because PIECEWISE captures
# run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update).
# See: https://github.com/vllm-project/vllm/issues/34094
if cudagraph_mode.has_full_cudagraphs():
has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups
)
if has_mamba and self.kv_cache_config is not None:
self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache(
self.kv_cache_config.num_blocks
)
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
self.compilation_config.cudagraph_mode = cudagraph_mode