[V1][CUDA] Full cudagraph support for FlashInfer (#21367)

fhl2000
2025-08-02 09:49:34 +08:00
committed by GitHub
parent 3654847db5
commit 23322431c8
8 changed files with 376 additions and 47 deletions


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")
class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend.
    Cascade attention is not considered here, as it is currently
    never supported under cudagraph."""

    NEVER = 0
    """No cudagraph support"""
    PURE_DECODE_ONLY = 1
    """Cudagraph supported for pure-decode batches only; mixed
    prefill-decode batches must run without cudagraph"""
    ALWAYS = 2
    """Cudagraph always supported"""


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    # Does this backend/builder support CUDA Graphs for attention.
    full_cudagraph_supported: ClassVar[bool] = False
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER

    @abstractmethod
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
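
For orientation, below is a minimal, self-contained sketch of how the new enum could be consumed: a backend's metadata builder advertises its support level via the class attribute, and the runner decides per batch whether full-cudagraph execution is allowed. The FlashInferLikeBuilder subclass and the can_use_full_cudagraph helper are illustrative names only and are not part of this commit.

import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    # Mirrors the enum added in the diff above.
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


class AttentionMetadataBuilder:
    # Default, as in the diff: no cudagraph support.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


class FlashInferLikeBuilder(AttentionMetadataBuilder):
    # Hypothetical backend that can run all batch types under cudagraph capture.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS


def can_use_full_cudagraph(builder: AttentionMetadataBuilder,
                           pure_decode: bool) -> bool:
    # Hypothetical helper: decide whether this batch may be captured/replayed
    # as a full cudagraph, given the builder's advertised support level.
    support = builder.attn_cudagraph_support
    if support is AttentionCGSupport.ALWAYS:
        return True
    if support is AttentionCGSupport.PURE_DECODE_ONLY:
        return pure_decode
    return False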