[Misc] Log the reason for falling back to FlexAttention (#20699)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -3,6 +3,7 @@
 import os
 from contextlib import contextmanager
+from dataclasses import dataclass
 from functools import cache
 from typing import Generator, Optional, Union
 
@@ -79,32 +80,62 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
-def supports_head_size(
+@dataclass(frozen=True)
+class _IsSupported:
+    can_import: bool
+    head_size: bool
+    dtype: bool
+
+    def __bool__(self) -> bool:
+        return self.can_import and self.head_size and self.dtype
+
+
+def is_attn_backend_supported(
     attn_backend: Union[str, type[AttentionBackend]],
     head_size: int,
-) -> bool:
+    dtype: torch.dtype,
+    *,
+    allow_import_error: bool = True,
+) -> _IsSupported:
     if isinstance(attn_backend, str):
         try:
             attn_backend = resolve_obj_by_qualname(attn_backend)
         except ImportError:
-            return False
+            if not allow_import_error:
+                raise
+
+            return _IsSupported(can_import=False, head_size=False, dtype=False)
 
     assert isinstance(attn_backend, type)
 
     # TODO: Update the interface once V0 is removed
     if get_supported_head_sizes := getattr(attn_backend,
                                            "get_supported_head_sizes", None):
-        return head_size in get_supported_head_sizes()
-    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        is_head_size_supported = head_size in get_supported_head_sizes()
+    elif validate_head_size := getattr(attn_backend, "validate_head_size",
+                                       None):
         try:
             validate_head_size(head_size)
-            return True
+            is_head_size_supported = True
         except Exception:
-            return False
-
-    raise NotImplementedError(f"{attn_backend.__name__} does not support "
-                              "head size validation")
+            is_head_size_supported = False
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "head size validation")
+
+    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
+                                       None):
+        is_dtype_supported = dtype in get_supported_dtypes()
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "dtype validation")
+
+    return _IsSupported(
+        can_import=True,
+        head_size=is_head_size_supported,
+        dtype=is_dtype_supported,
+    )
 
 
 def get_attn_backend(
     head_size: int,
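A minimal standalone sketch of how the new result type above is meant to be consumed (the backend outcome, head size, and dtype values below are illustrative, not taken from the commit): an _IsSupported value is truthy only when every check passes, and its per-field flags are what let the caller report which check failed.

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class _IsSupported:
    # Mirrors the result type added to the attention selector above.
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        # Truthy only when the backend imports and accepts both the
        # requested head size and dtype.
        return self.can_import and self.head_size and self.dtype


# Hypothetical outcome: the backend imported fine but rejects fp32.
result = _IsSupported(can_import=True, head_size=True, dtype=False)

if result:
    print("default backend is usable")
else:
    # The individual flags tell the caller *why* it has to fall back.
    reasons = {}
    if not result.head_size:
        reasons["head_size"] = 72            # illustrative value
    if not result.dtype:
        reasons["dtype"] = torch.float32     # illustrative value
    print("falling back:", ", ".join(f"{k}={v}" for k, v in reasons.items()))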
@@ -259,45 +259,58 @@ class CudaPlatformBase(Platform):
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return FLASH_ATTN_V1
 
-        from vllm.attention.selector import supports_head_size
+        from vllm.attention.selector import is_attn_backend_supported
 
         # Default backends for V1 engine
-        # FP32 is only supported by FlexAttention
-        if dtype not in (torch.float16, torch.bfloat16):
-            logger.info_once(
-                "Using FlexAttention backend for %s on V1 engine.",
-                dtype,
-            )
-            return FLEX_ATTENTION_V1
-
         # Prefer FlashInfer for Blackwell GPUs if installed
-        if cls.is_device_capability(100) and \
-            supports_head_size(FLASHINFER_V1, head_size):
-            try:
-                import flashinfer  # noqa: F401
+        if cls.is_device_capability(100):
+            if is_default_backend_supported := is_attn_backend_supported(
+                    FLASHINFER_V1, head_size, dtype):
 
                 from vllm.v1.attention.backends.utils import (
                     set_kv_cache_layout)
 
                 logger.info_once(
                     "Using FlashInfer backend with HND KV cache layout on "
                     "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                 set_kv_cache_layout("HND")
 
                 return FLASHINFER_V1
-            except ImportError:
-                logger.info_once(
+
+            if not is_default_backend_supported.can_import:
+                logger.warning_once(
                     "FlashInfer failed to import for V1 engine on "
                     "Blackwell (SM 10.0) GPUs; it is recommended to "
                     "install FlashInfer for better performance.")
-                pass
 
         # FlashAttention is the default for SM 8.0+ GPUs
-        if cls.has_device_capability(80) and \
-            supports_head_size(FLASH_ATTN_V1, head_size):
-            logger.info_once("Using Flash Attention backend on V1 engine.")
-            return FLASH_ATTN_V1
-
-        logger.info_once("Using FlexAttention backend on V1 engine.")
-        return FLEX_ATTENTION_V1
+        if cls.has_device_capability(80):
+            if is_default_backend_supported := is_attn_backend_supported(
+                    FLASH_ATTN_V1, head_size, dtype,
+                    allow_import_error=False):
+                logger.info_once("Using Flash Attention backend on "
+                                 "V1 engine.")
+                return FLASH_ATTN_V1
+
+        # FlexAttention is the default for older GPUs
+        else:
+            logger.info_once("Using FlexAttention backend on V1 engine.")
+            return FLEX_ATTENTION_V1
+
+        assert not is_default_backend_supported
+
+        use_flex_attention_reason = {}
+        if not is_default_backend_supported.head_size:
+            use_flex_attention_reason["head_size"] = head_size
+        if not is_default_backend_supported.dtype:
+            use_flex_attention_reason["dtype"] = dtype
+
+        logger.info_once(
+            "Using FlexAttention backend for %s on V1 engine.",
+            ", ".join(f"{k}={v}"
+                      for k, v in use_flex_attention_reason.items()),
+        )
+        return FLEX_ATTENTION_V1
 
         # Backends for V0 engine
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
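For the CudaPlatformBase hunk above, a short sketch of the string the new fallback log line evaluates to; head_size=80 and dtype=torch.float32 are illustrative values, not taken from the commit.

import torch

# Assume the default backend check reported both fields as unsupported.
head_size, dtype = 80, torch.float32
is_head_size_supported, is_dtype_supported = False, False

use_flex_attention_reason = {}
if not is_head_size_supported:
    use_flex_attention_reason["head_size"] = head_size
if not is_dtype_supported:
    use_flex_attention_reason["dtype"] = dtype

# Same formatting as the logger.info_once() call in the hunk above.
reason = ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items())
print("Using FlexAttention backend for %s on V1 engine." % reason)
# -> Using FlexAttention backend for head_size=80, dtype=torch.float32 on V1 engine.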
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import re
 from collections.abc import Sequence
 from typing import Optional, Union
 
+import regex as re
 from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -37,6 +37,10 @@ logger = init_logger(__name__)
 class TorchSDPABackend(AttentionBackend):
     accept_output_buffer: bool = False
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         attn_impl = _get_paged_attn_impl()
@@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
@@ -42,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     cached_sm100a_supported: Optional[bool] = None
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
@@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
 class FlexAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         return  # FlexAttention supports any head size
@@ -262,6 +262,10 @@ class MLACommonBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [576]
@@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
@@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
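The backend hunks above all add the same kind of classmethod: each V1 backend now advertises the dtypes it supports alongside its head-size information, so the selector can probe backends with getattr() instead of hard-coding an fp16/bf16 check. A toy sketch of that probing pattern, using a stand-in backend class rather than vLLM's real ones:

import torch


class ToyBackend:
    # Stand-in that mirrors the classmethods added in the hunks above.
    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]


def probe(backend: type, head_size: int, dtype: torch.dtype) -> dict[str, bool]:
    # getattr-based probing, as in is_attn_backend_supported(); backends that
    # do not expose a method are simply skipped here, whereas the real
    # selector raises NotImplementedError instead.
    checks = {}
    if get_supported_head_sizes := getattr(backend, "get_supported_head_sizes", None):
        checks["head_size"] = head_size in get_supported_head_sizes()
    if get_supported_dtypes := getattr(backend, "get_supported_dtypes", None):
        checks["dtype"] = dtype in get_supported_dtypes()
    return checks


print(probe(ToyBackend, 128, torch.bfloat16))  # {'head_size': True, 'dtype': True}
print(probe(ToyBackend, 128, torch.float32))   # {'head_size': True, 'dtype': False}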