[V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035)

Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>
This commit is contained in:
Asaf Joseph Gardin
2025-08-21 06:08:51 +03:00
committed by GitHub
parent 2461d9e562
commit 3663870c72
9 changed files with 154 additions and 87 deletions

View File

@@ -2,16 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend):
@@ -31,24 +31,11 @@ class Mamba1AttentionMetadata:
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_padded_decodes: int
class Mamba1AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba1AttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    vllm_config: VllmConfig,
    device: torch.device,
    layer_names: list[str],
):
    """Capture the KV-cache spec and runtime context used when building metadata.

    Args:
        kv_cache_spec: Cache layout spec; must be a ``MambaSpec`` for this builder.
        vllm_config: Full vLLM configuration object.
        device: Torch device the metadata tensors will live on.
        layer_names: Names of the attention layers this builder serves.
    """
    # Mamba1 metadata only makes sense for a Mamba-style cache spec.
    assert isinstance(kv_cache_spec, MambaSpec)
    # Stash everything verbatim for later use by build().
    self.vllm_config = vllm_config
    self.kv_cache_spec = kv_cache_spec
    self.layer_names = layer_names
    self.device = device
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):
def build(
self,
@@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder(
decode_threshold=1))
has_initial_states = None
padded_decodes = num_decodes
if num_prefills > 0:
has_initial_states = context_lens_tensor > 0
elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph):
state_indices_for_decode = state_indices_tensor[:num_decodes]
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(
state_indices_for_decode, non_blocking=True)
state_indices_tensor = self.state_indices_tensor[:padded_decodes]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
return Mamba1AttentionMetadata(
query_start_loc=query_start_loc,
@@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder(
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_padded_decodes=padded_decodes,
)