[Model] Remove transformers attention porting in VITs (#10414)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2024-11-18 21:45:21 +08:00 (committed by GitHub)
Parent: 5be4e52b65
Commit: e7ebb662d7
7 changed files with 139 additions and 102 deletions


@@ -587,7 +587,11 @@ class LLMWrapper(nn.Module):
         return llm(*args, **kwargs)
 
 
-def get_vit_attn_backend() -> _Backend:
+def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+    """
+    Get the available attention backend for Vision Transformer.
+    """
+    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
     selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
     if selected_backend is None:
         backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
@@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
     if selected_backend is None:
         # For Volta and Turing GPUs, use xformers instead.
         device_available = current_platform.has_device_capability(80)
-        if device_available:
+        if device_available and support_fa:
             from transformers.utils import is_flash_attn_2_available
             if is_flash_attn_2_available():
                 selected_backend = _Backend.FLASH_ATTN
@@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
                     "so we use xformers backend instead. You can run "
                     "`pip install flash-attn` to use flash-attention backend.")
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu():
+        elif current_platform.is_cpu() or current_platform.is_rocm():
+            # ROCM doesn't support xformers
             selected_backend = _Backend.TORCH_SDPA
         else:
             selected_backend = _Backend.XFORMERS
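
For reference, a minimal usage sketch (not part of this commit) of how a vision attention module could consume the updated helper: it opts in to flash-attention with support_fa=True and dispatches on the backend the helper returns. The class name SketchVisionAttention and the import path used for _Backend are assumptions for illustration; only get_vit_attn_backend, the support_fa flag, and the backend enum values come from the diff above.

import torch
import torch.nn.functional as F

from vllm.attention.selector import _Backend  # assumed location of _Backend
from vllm.model_executor.models.utils import get_vit_attn_backend


class SketchVisionAttention(torch.nn.Module):
    """Hypothetical ViT attention layer that dispatches on the selected backend."""

    def __init__(self) -> None:
        super().__init__()
        # Pass support_fa=True only if this ViT's attention has been verified
        # with flash-attention; otherwise the helper falls back to xformers on
        # GPU, or TORCH_SDPA on CPU/ROCm.
        self.attn_backend = get_vit_attn_backend(support_fa=True)

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        # q, k, v: (batch, seq_len, num_heads, head_dim)
        if self.attn_backend == _Backend.FLASH_ATTN:
            from flash_attn import flash_attn_func
            return flash_attn_func(q, k, v)
        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            return xops.memory_efficient_attention(q, k, v)
        # TORCH_SDPA fallback (CPU / ROCm); SDPA expects (batch, heads, seq, dim).
        out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                             k.transpose(1, 2),
                                             v.transpose(1, 2))
        return out.transpose(1, 2)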