[V1] [Hybrid] Enable Full CUDA Graph (decode-only) for Mamba layers (#21401)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Date: 2025-08-10 05:16:11 +02:00 (committed by GitHub)
parent 42172ad18f
commit 61f67d8acd
2 changed files with 103 additions and 1 deletion
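Summary: the Mamba2 attention metadata builder now declares decode-only full CUDA graph support, allocates a persistent state-indices buffer sized for the largest captured batch, and pads pure-decode batches up to the nearest capture size so a captured graph can be replayed. As a usage illustration only — the full_cuda_graph flag and the model name below are assumptions for the sketch, not taken from this commit:

# Hypothetical usage sketch; flag name and model are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B",  # example hybrid Mamba2 model
    compilation_config={"full_cuda_graph": True},  # opt in to full CUDA graphs
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)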

vllm/v1/attention/backends/mamba_attn.py

@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
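For reference, the newly imported AttentionCGSupport is an enum through which each attention backend declares its level of full CUDA graph support. A rough sketch of its shape; the members other than PURE_DECODE_ONLY (which this diff uses) are assumptions:

# Rough sketch of the enum; member set besides PURE_DECODE_ONLY is assumed.
import enum

class AttentionCGSupport(enum.Enum):
    NEVER = 0             # backend cannot run inside a full CUDA graph
    PURE_DECODE_ONLY = 1  # full graph only for batches where query_len == 1
    ALWAYS = 2            # full graph for any batch composition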
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
 
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
     reorder_batch_threshold: ClassVar[int] = 1
 
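Declaring PURE_DECODE_ONLY together with reorder_batch_threshold = 1 means the runner groups single-token (decode) requests at the front of the batch; only a batch made up entirely of decodes can be replayed as one full graph. An illustrative sketch of that split, not vLLM's implementation:

# Illustrative sketch (not vLLM's split_decodes_and_prefills): with the
# batch reordered so decodes come first, count the leading requests whose
# query_len is within the threshold.
def split_decodes_and_prefills_sketch(query_lens: list[int],
                                      threshold: int = 1) -> tuple[int, int]:
    num_decodes = 0
    for qlen in query_lens:
        if qlen <= threshold:
            num_decodes += 1
        else:
            break  # first prefill ends the decode run
    return num_decodes, len(query_lens) - num_decodes

# e.g. [1, 1, 1] -> (3, 0): pure decode, eligible for full-graph replay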
@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_capture_size)
self.state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
def build(self,
common_prefix_len: int,
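The persistent state_indices_tensor is the key to graph replay: a CUDA graph bakes in device addresses, so the builder allocates one buffer up front, sized for the largest decode batch any captured graph can see. A small numeric sketch of that sizing, with example values:

# Example values only; the point is that no captured graph can ever need
# more than min(max_num_seqs, max_capture_size) decode slots.
import torch

max_num_seqs = 256      # scheduler_config.max_num_seqs (example value)
max_capture_size = 512  # compilation_config.max_capture_size (example value)
decode_cudagraph_max_bs = min(max_num_seqs, max_capture_size)  # -> 256

# One persistent buffer; graph replays always read/write this same storage.
state_indices = torch.empty((decode_cudagraph_max_bs,), dtype=torch.int32)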
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p, self.chunk_size,
num_prefill_tokens))
elif num_decodes <= self.decode_cudagraph_max_bs:
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
non_blocking=True)
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
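Here pad_for_cudagraph rounds the decode count up to the nearest captured batch size, and the padded tail is filled with PAD_SLOT_ID so the extra lanes write Mamba state to a dummy slot instead of a real request's state. A sketch of the rounding, with an assumed list of capture sizes:

# Sketch only; real capture sizes come from the compilation config.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]  # assumed example
PAD_SLOT_ID = -1  # sentinel imported from vllm.attention.backends.utils

def pad_for_cudagraph_sketch(batch_size: int) -> int:
    # Round up to the smallest captured graph that fits the batch.
    return min(s for s in CAPTURE_SIZES if s >= batch_size)

num_decodes = 5
num_input_tokens = pad_for_cudagraph_sketch(num_decodes)  # -> 8
# Slots [5:8] would then be set to PAD_SLOT_ID before replay.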
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
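The two new hooks split responsibilities: build_for_cudagraph_capture builds metadata for a dummy pure-decode batch at capture time (every request contributes exactly one token, hence num_reqs == num_actual_tokens), while can_run_in_cudagraph is the cheap per-step test for whether a batch may be replayed. An illustrative dispatch sketch; run_eager and replay_graph are hypothetical callables standing in for the model runner:

# Illustrative control flow, not vLLM's runner code.
def execute(builder, common_attn_metadata, run_eager, replay_graph):
    if builder.can_run_in_cudagraph(common_attn_metadata):
        # Pure-decode batch (max_query_len == 1): replay the captured graph.
        return replay_graph(common_attn_metadata)
    # Mixed prefill/decode batch: fall back to eager execution.
    return run_eager(common_attn_metadata)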