[Bugfix][Kernel] Fix CUDA 11.8 being broken by FA3 build (#12375)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Author: Lucas Wilkinson
Date: 2025-01-24 10:27:59 -05:00
Committed by: GitHub
Parent: 3bb8e2c9a2
Commit: ab5bbf5ae3
6 changed files with 42 additions and 22 deletions

vllm/v1/attention/backends/flash_attn.py (11 changed lines; file mode changed from normal to executable)

@@ -10,11 +10,15 @@ import triton.language as tl
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
                                   is_fa_version_supported)
+
+logger = init_logger(__name__)

 class FlashAttentionBackend(AttentionBackend):
@@ -143,6 +147,11 @@ class FlashAttentionImpl(AttentionImpl):
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION
+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d; it is not supported: %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+
         assert is_fa_version_supported(self.fa_version)

     def forward(
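
For context, here is a minimal sketch (not part of the commit) of the guard pattern the second hunk introduces: ask vllm.vllm_flash_attn whether a FlashAttention version is usable and, if not, log the human-readable reason from fa_version_unsupported_reason before giving up. The pick_fa_version helper and its 3-then-2 fallback order are illustrative assumptions, not vLLM's actual selection logic.

# Illustrative sketch only: reuses the two helpers imported in the diff above.
from typing import Optional

from vllm.logger import init_logger
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)

logger = init_logger(__name__)


def pick_fa_version(requested: Optional[int] = None) -> int:
    """Return a usable FlashAttention version, logging why others are rejected.

    Hypothetical helper for illustration; the fallback order (3, then 2) is an
    assumption, not taken from the commit.
    """
    candidates = [requested] if requested is not None else [3, 2]
    for version in candidates:
        if is_fa_version_supported(version):
            return version
        # Same reporting pattern as the diff: surface the reason (e.g. FA3
        # kernels missing from a CUDA 11.8 build) instead of a bare assert.
        logger.error("Cannot use FA version %d; it is not supported: %s",
                     version, fa_version_unsupported_reason(version))
    raise RuntimeError("No supported FlashAttention version is available")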