[Bugfix][V1] Allow manual FlashAttention for Blackwell (#19492)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -226,15 +226,21 @@ class CudaPlatformBase(Platform):
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
                 return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
-            if selected_backend == _Backend.FLEX_ATTENTION:
+            elif selected_backend == _Backend.FLEX_ATTENTION:
                 logger.info("Using FlexAttenion backend on V1 engine.")
                 return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
-            # Prefer FlashInfer for V1 on Blackwell GPUs if installed
+            elif selected_backend == _Backend.FLASH_ATTN:
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
+
+            # Default backends for V1 engine
+            # Prefer FlashInfer for Blackwell GPUs if installed
             if cls.is_device_capability(100):
                 try:
                     import flashinfer  # noqa: F401
                     logger.info_once(
@@ -248,10 +254,13 @@ class CudaPlatformBase(Platform):
                         "Blackwell (SM 10.0) GPUs; it is recommended to "
                         "install FlashInfer for better performance.")
                     pass
-            if cls.has_device_capability(80):
+            # FlashAttention is the default for SM 8.0+ GPUs
+            elif cls.has_device_capability(80):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "flash_attn.FlashAttentionBackend")
+
+        # Backends for V0 engine
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
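
The net effect of the two hunks: an explicit FLASH_ATTN selection is now handled in the V1 elif chain instead of falling through to the capability-based defaults, and the Blackwell and SM 8.0+ defaults are chained with elif under the "Default backends for V1 engine" heading. Below is a minimal, self-contained sketch of that selection order, not the actual vLLM implementation (the real logic lives in CudaPlatformBase.get_attn_backend_cls); the function name, boolean parameters, and plain-string return values are illustrative only.

def pick_v1_backend(selected_backend, is_blackwell, flashinfer_installed,
                    sm80_or_newer):
    """Simplified sketch of the V1 selection order established by this diff."""
    # Explicit selections are honored first. The new FLASH_ATTN branch is what
    # lets a manual FlashAttention choice win on Blackwell; previously it fell
    # through to the FlashInfer-preferring default below.
    if selected_backend == "FLASHINFER":
        return "FlashInferBackend"
    elif selected_backend == "FLEX_ATTENTION":
        return "FlexAttentionBackend"
    elif selected_backend == "TRITON_ATTN_VLLM_V1":
        return "TritonAttentionBackend"
    elif selected_backend == "FLASH_ATTN":
        return "FlashAttentionBackend"

    # Default backends for V1 engine
    if is_blackwell:
        if flashinfer_installed:
            # Prefer FlashInfer on Blackwell (SM 10.0) when it is importable.
            return "FlashInferBackend"
        # FlashInfer missing: the real code logs a recommendation to install
        # it and falls through to defaults not shown in this diff.
    elif sm80_or_newer:
        # FlashAttention is the default for other SM 8.0+ GPUs.
        return "FlashAttentionBackend"
    return None

# On Blackwell with FlashInfer installed, a manual FLASH_ATTN selection now
# wins instead of being overridden by the FlashInfer default.
assert pick_v1_backend("FLASH_ATTN", True, True, True) == "FlashAttentionBackend"
assert pick_v1_backend(None, True, True, True) == "FlashInferBackend"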
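
In practice, the manual selection this fix unblocks is typically made through vLLM's VLLM_ATTENTION_BACKEND environment variable. A minimal usage sketch, assuming a Blackwell (SM 10.0) GPU with FlashInfer installed and that the value FLASH_ATTN maps to _Backend.FLASH_ATTN as in the diff; the model name is only a placeholder:

import os

# Force the attention backend before the engine resolves platform defaults.
# With this fix, the V1 engine on SM 10.0 honors the selection instead of
# silently preferring FlashInfer.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)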