[FEAT][ROCm] Enable running Flash Attention as ViT attn backend for Qwen-VL models on ROCm platform. (#22069)

Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
vllmellm
2025-08-02 14:53:18 +08:00
committed by GitHub
parent 0edaf752d7
commit d3a6f2120b
6 changed files with 64 additions and 39 deletions

View File

@@ -7,9 +7,7 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
import torch
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.attention.selector import get_env_variable_attn_backend
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
@@ -75,32 +73,12 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
if current_platform.is_cuda():
device_available = current_platform.has_device_capability(80)
if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN
else:
logger.warning_once(
"Current `vllm-flash-attn` has a bug inside vision "
"module, so we use xformers backend instead. You can "
"run `pip install flash-attn` to use flash-attention "
"backend.")
selected_backend = _Backend.XFORMERS
else:
# For Volta and Turing GPUs, use xformers instead.
selected_backend = _Backend.XFORMERS
else:
# Default to torch SDPA for other non-GPU platforms.
selected_backend = _Backend.TORCH_SDPA
return selected_backend
selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
if selected_backend is not None:
return selected_backend
return current_platform.get_vit_attn_backend(support_fa)
def resolve_visual_encoder_outputs(