[V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035)

Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>
Author:     Asaf Joseph Gardin
Date:       2025-08-21 06:08:51 +03:00
Committed:  GitHub
Commit:     3663870c72 (parent 2461d9e562)
9 changed files with 154 additions and 87 deletions

vllm/v1/attention/backends/mamba1_attn.py

@@ -2,16 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional

 import torch

 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              split_decodes_and_prefills)
+from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


 class Mamba1AttentionBackend(AttentionBackend):

@@ -31,24 +31,11 @@ class Mamba1AttentionMetadata:
     num_prefill_tokens: int
     num_decodes: int
     num_decode_tokens: int
+    num_padded_decodes: int


 class Mamba1AttentionMetadataBuilder(
-        AttentionMetadataBuilder[Mamba1AttentionMetadata]):
-
-    reorder_batch_threshold: ClassVar[int] = 1
-
-    def __init__(
-        self,
-        kv_cache_spec: AttentionSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        layer_names: list[str],
-    ):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
-        self.device = device
-        self.vllm_config = vllm_config
-        self.layer_names = layer_names
+        BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):

     def build(
         self,

@@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder(
                     decode_threshold=1))

         has_initial_states = None
+        padded_decodes = num_decodes
         if num_prefills > 0:
             has_initial_states = context_lens_tensor > 0
+        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
+              and self.compilation_config.full_cuda_graph):
+            state_indices_for_decode = state_indices_tensor[:num_decodes]
+            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(
+                state_indices_for_decode, non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

         return Mamba1AttentionMetadata(
             query_start_loc=query_start_loc,

@@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder(
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
+            num_padded_decodes=padded_decodes,
         )
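The elif branch added above is the heart of the full-CUDA-graph path for Mamba1: when the batch is pure decode, it is padded up to the nearest captured graph size, the live slot indices are copied into a persistent buffer allocated at builder init, and the padded tail is filled with PAD_SLOT_ID so the captured graph always runs on a fixed shape. A minimal standalone sketch of that padding pattern follows; CAPTURE_SIZES, pad_for_cudagraph, and pad_decode_state_indices are illustrative stand-ins (only VllmConfig.pad_for_cudagraph and the buffer handling correspond to the diff itself):

import torch

PAD_SLOT_ID = -1  # padding sentinel, as in vllm.attention.backends.utils

# Stand-in for the configured CUDA graph capture sizes (illustrative values).
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]


def pad_for_cudagraph(num_decodes: int) -> int:
    # Stand-in for VllmConfig.pad_for_cudagraph: round the decode batch size
    # up to the nearest captured size.
    return next(size for size in CAPTURE_SIZES if size >= num_decodes)


# Persistent buffer sized to decode_cudagraph_max_bs, allocated once.
state_indices_buf = torch.empty(CAPTURE_SIZES[-1], dtype=torch.int32)


def pad_decode_state_indices(state_indices: torch.Tensor) -> torch.Tensor:
    """Copy live slot indices into the persistent buffer and pad the tail
    with PAD_SLOT_ID so the captured graph always sees a fixed-size tensor."""
    num_decodes = state_indices.numel()
    padded_decodes = pad_for_cudagraph(num_decodes)
    state_indices_buf[:num_decodes].copy_(state_indices)
    padded = state_indices_buf[:padded_decodes]
    padded[num_decodes:] = PAD_SLOT_ID
    return padded


# Three live decode requests get padded up to the next capture size (4).
print(pad_decode_state_indices(torch.tensor([5, 9, 2], dtype=torch.int32)))
# tensor([ 5,  9,  2, -1], dtype=torch.int32)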

vllm/v1/attention/backends/mamba2_attn.py

@@ -2,18 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional

 import torch

 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionCGSupport,
-                                              AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
+from vllm.v1.kv_cache_interface import AttentionSpec


 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,

@@ -88,29 +88,14 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
-        AttentionMetadataBuilder[Mamba2AttentionMetadata]):
-
-    cudagraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    reorder_batch_threshold: ClassVar[int] = 1
+        BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]):

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
-        self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
-        self.state_indices_tensor = torch.empty(
-            (self.decode_cudagraph_max_bs, ),
-            dtype=torch.int32,
-            device=device,
-        )

     def build(self,
               common_prefix_len: int,

@@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder(
             state_indices_tensor=state_indices_tensor,
         )

         return attn_metadata
-
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-
-        Currently, only decode is supported for full cudagraphs with Mamba.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "Mamba only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
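The build_for_cudagraph_capture method removed here reappears in the new shared base class below. The precondition it enforces is that full-graph capture is decode-only: every scheduled request contributes exactly one token, so num_reqs must equal num_actual_tokens and max_query_len can be pinned to 1. A tiny self-contained illustration of that check; FakeCommonMetadata and prepare_for_capture are hypothetical stand-ins, not vLLM's CommonAttentionMetadata or builder API:

from dataclasses import dataclass


@dataclass
class FakeCommonMetadata:  # stand-in for CommonAttentionMetadata
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int = 0


def prepare_for_capture(m: FakeCommonMetadata) -> FakeCommonMetadata:
    # Decode-only capture: each request must contribute exactly one new token.
    assert m.num_reqs == m.num_actual_tokens, (
        "full CUDA graph capture with Mamba is decode-only")
    m.max_query_len = 1  # every query is a single decode token
    return m


print(prepare_for_capture(FakeCommonMetadata(num_reqs=8, num_actual_tokens=8)))
# FakeCommonMetadata(num_reqs=8, num_actual_tokens=8, max_query_len=1)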

vllm/v1/attention/backends/mamba_attn.py (new file)

@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import abc
+from typing import ClassVar, TypeVar
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
+from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+
+M = TypeVar("M")
+
+
+class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
+    reorder_batch_threshold: ClassVar[int] = 1
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        assert isinstance(kv_cache_spec, MambaSpec)
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        self.vllm_config = vllm_config
+        self.layer_names = layer_names
+
+        self.compilation_config = vllm_config.compilation_config
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        This method builds the metadata for full cudagraph capture.
+
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
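The new base class is generic over the metadata type M, so each backend parameterizes it with its own metadata dataclass and only implements build(); the shared __init__ and the decode-only capture path come along for free. A compact, self-contained sketch of that pattern with illustrative names (MetadataT, BaseBuilder, Mamba1Meta, and Mamba1Builder are stand-ins, not vLLM classes):

import abc
from dataclasses import dataclass
from typing import Generic, TypeVar

MetadataT = TypeVar("MetadataT")


class BaseBuilder(Generic[MetadataT], abc.ABC):
    """Shared setup and capture path; subclasses only implement build()."""

    @abc.abstractmethod
    def build(self, num_decodes: int) -> MetadataT:
        ...

    def build_for_cudagraph_capture(self, num_decodes: int) -> MetadataT:
        # The shared capture path simply delegates to the subclass build().
        return self.build(num_decodes)


@dataclass
class Mamba1Meta:
    num_padded_decodes: int


class Mamba1Builder(BaseBuilder[Mamba1Meta]):
    def build(self, num_decodes: int) -> Mamba1Meta:
        return Mamba1Meta(num_padded_decodes=num_decodes)


print(Mamba1Builder().build_for_cudagraph_capture(4))
# Mamba1Meta(num_padded_decodes=4)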