[V1][CUDA] Full cudagraph support for FlashInfer (#21367)
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend.

    Cascade attention is not considered here, as it currently never
    supports cudagraphs."""
    NEVER = 0
    """No cudagraph support"""
    PURE_DECODE_ONLY = 1
    """Cudagraph supported for pure-decode batches; mixed
    prefill-decode batches must run without cudagraph"""
    ALWAYS = 2
    """Cudagraph always supported"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    # Does this backend/builder support CUDA Graphs for attention?
    full_cudagraph_supported: ClassVar[bool] = False
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER

    @abstractmethod
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
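To show how a backend builder could opt in under this scheme, here is a minimal hypothetical subclass; the names `MyBackendMetadata` and `MyBackendMetadataBuilder` are invented for illustration and do not appear in this diff.

# Hypothetical sketch, not code from this diff: a builder declaring
# cudagraph support for pure-decode batches only.
@dataclass
class MyBackendMetadata:  # placeholder metadata type for the sketch
    num_decodes: int = 0

class MyBackendMetadataBuilder(AttentionMetadataBuilder[MyBackendMetadata]):
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY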