From af09b3f0a05479e36bab12ae4e05cbdd3fed9d27 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 12 Jun 2025 06:40:24 -0400 Subject: [PATCH] [Bugfix][V1] Allow manual FlashAttention for Blackwell (#19492) Signed-off-by: mgoin --- vllm/platforms/cuda.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 7ab5146fd..2d07ddc36 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -226,15 +226,21 @@ class CudaPlatformBase(Platform): if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" - if selected_backend == _Backend.FLEX_ATTENTION: + elif selected_backend == _Backend.FLEX_ATTENTION: logger.info("Using FlexAttenion backend on V1 engine.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: + elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1: logger.info_once("Using Triton backend on V1 engine.") return ("vllm.v1.attention.backends." "triton_attn.TritonAttentionBackend") + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "flash_attn.FlashAttentionBackend") + + # Default backends for V1 engine + # Prefer FlashInfer for Blackwell GPUs if installed if cls.is_device_capability(100): - # Prefer FlashInfer for V1 on Blackwell GPUs if installed try: import flashinfer # noqa: F401 logger.info_once( @@ -248,10 +254,13 @@ class CudaPlatformBase(Platform): "Blackwell (SM 10.0) GPUs; it is recommended to " "install FlashInfer for better performance.") pass - if cls.has_device_capability(80): + # FlashAttention is the default for SM 8.0+ GPUs + elif cls.has_device_capability(80): logger.info_once("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "flash_attn.FlashAttentionBackend") + + # Backends for V0 engine if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend"