Files
vllm/vllm/v1/attention/backends/fa_utils.py
Matthew Bonanni 0308901975 [2/N][Attention] Fix pre-commit errors (#32052)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2026-01-10 00:27:15 +00:00

127 lines
3.9 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)

# Bind the platform-specific flash-attention entry points at import time.
# Which of `reshape_and_cache_flash`, `flash_attn_varlen_func` and
# `get_scheduler_metadata` end up defined depends on the branch taken below:
# on ROCm only `flash_attn_varlen_func` is bound, and on a platform that is
# none of CUDA/XPU/ROCm none of them are bound.
if current_platform.is_cuda():
    # CUDA: cache-write op from vllm._custom_ops, attention kernels from the
    # vllm_flash_attn package.
    from vllm._custom_ops import reshape_and_cache_flash
    from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
        flash_attn_varlen_func,
        get_scheduler_metadata,
    )
elif current_platform.is_xpu():
    # XPU: all three entry points are taken from the IPEX ops wrapper.
    from vllm._ipex_ops import ipex_ops

    reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
    flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func
    get_scheduler_metadata = ipex_ops.get_scheduler_metadata
elif current_platform.is_rocm():
    # ROCm: use the upstream `flash_attn` package; fail at import time with an
    # actionable message if it is not installed.
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Rocm platform requires upstream flash-attn "
            "to be installed. Please install flash-attn first."
        ) from e
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    """Return the flash-attention version to use on the current platform.

    Resolution order:

    1. XPU always uses version 2. ROCm returns ``None`` because it does not
       use ``vllm_flash_attn``, so no ``fa_version`` argument should be passed.
    2. Otherwise default to FA3 on devices with compute-capability major 9
       when the installed ``vllm_flash_attn`` supports it, else FA2.
    3. A non-``None`` ``attention_config.flash_attn_version`` in the current
       vLLM config overrides that default.
    4. FA3 is downgraded to FA2 on capability-major-10 devices and when
       ALiBi is required.

    Args:
        requires_alibi: Whether the caller needs ALiBi, which forces FA2.

    Returns:
        The selected version, or ``None`` if ``vllm_flash_attn`` cannot be
        imported, the device capability is unknown, or the selected version
        is not supported by the installed build.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        return 2
    if current_platform.is_rocm():
        # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
        return None
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )

        device_capability = current_platform.get_device_capability()
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an uncaught AttributeError.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )

        # 2. override if passed by environment or config
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if (
            vllm_config is not None
            and vllm_config.attention_config.flash_attn_version is not None
        ):
            fa_version = vllm_config.attention_config.flash_attn_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2

        # Explicit check instead of `assert` so an unsupported version is
        # never returned, even under `python -O`.
        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d: it is not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
            return None
        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still caught in case a helper called above
        # asserts internally; that case is treated as "no usable version".
        return None
def flash_attn_supports_fp8() -> bool:
    """Report whether FP8 flash attention is usable.

    Requires that FA3 is the selected flash-attention version and that the
    device belongs to the capability-90 family.
    """
    if get_flash_attn_version() != 3:
        return False
    return current_platform.is_device_capability_family(90)
def flash_attn_supports_sinks() -> bool:
    """Report whether flash attention supports attention sinks.

    Always true on XPU; on other platforms FA3 must be the selected version.
    """
    return True if current_platform.is_xpu() else get_flash_attn_version() == 3
def flash_attn_supports_mla() -> bool:
    """Report whether flash-attention MLA is usable on this platform.

    Requires a CUDA platform, a ``vllm_flash_attn`` build that supports FA3,
    and a device in the capability-90 family. An ImportError or
    AssertionError while probing is treated as "not supported".
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return False
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported,
        )

        return is_fa_version_supported(
            3
        ) and current_platform.is_device_capability_family(90)
    except (ImportError, AssertionError):
        # Best-effort probe: a missing or broken build means no MLA support.
        return False
def is_flash_attn_varlen_func_available() -> bool:
    """Report whether ``flash_attn_varlen_func`` is treated as available.

    Only CUDA and XPU platforms are reported as available.
    """
    if current_platform.is_cuda():
        return True
    return current_platform.is_xpu()