[FEAT][ROCm] Enable running Flash Attention as ViT attn backend for Qwen-VL models on ROCm platform. (#22069)
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
@@ -246,11 +246,15 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
# Detect attention implementation.
|
||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
@@ -297,10 +301,13 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if self.is_flash_attn_backend:
|
||||
# from vllm_flash_attn.flash_attn_interface import (
|
||||
# flash_attn_varlen_func)
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
|
||||
@@ -311,7 +318,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0,
|
||||
dropout_p=0.0,
|
||||
causal=False)
|
||||
|
||||
context_layer = rearrange(output,
|
||||
@@ -635,7 +642,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> tuple[Optional[int], Optional[list[int]]]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if (self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
|
||||
@@ -274,10 +274,14 @@ class Qwen2VisionAttention(nn.Module):
|
||||
# Detect attention implementation.
|
||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2-VL does not support {self.attn_backend} backend now.")
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
@@ -324,10 +328,13 @@ class Qwen2VisionAttention(nn.Module):
|
||||
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if self.is_flash_attn_backend:
|
||||
# from vllm_flash_attn.flash_attn_interface import (
|
||||
# flash_attn_varlen_func)
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
|
||||
@@ -338,7 +345,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0,
|
||||
dropout_p=0.0,
|
||||
causal=False)
|
||||
|
||||
context_layer = rearrange(output,
|
||||
@@ -620,7 +627,8 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
self, cu_seqlens: torch.Tensor
|
||||
) -> tuple[Optional[int], Optional[list[int]]]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if (self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
|
||||
@@ -7,9 +7,7 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.attention.selector import get_env_variable_attn_backend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
|
||||
@@ -75,32 +73,12 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
|
||||
Get the available attention backend for Vision Transformer.
|
||||
"""
|
||||
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
|
||||
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
|
||||
if selected_backend is None:
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
if current_platform.is_cuda():
|
||||
device_available = current_platform.has_device_capability(80)
|
||||
if device_available and support_fa:
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
if is_flash_attn_2_available():
|
||||
selected_backend = _Backend.FLASH_ATTN
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Current `vllm-flash-attn` has a bug inside vision "
|
||||
"module, so we use xformers backend instead. You can "
|
||||
"run `pip install flash-attn` to use flash-attention "
|
||||
"backend.")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
else:
|
||||
# For Volta and Turing GPUs, use xformers instead.
|
||||
selected_backend = _Backend.XFORMERS
|
||||
else:
|
||||
# Default to torch SDPA for other non-GPU platforms.
|
||||
selected_backend = _Backend.TORCH_SDPA
|
||||
return selected_backend
|
||||
|
||||
selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
|
||||
if selected_backend is not None:
|
||||
return selected_backend
|
||||
|
||||
return current_platform.get_vit_attn_backend(support_fa)
|
||||
|
||||
|
||||
def resolve_visual_encoder_outputs(
|
||||
|
||||
Reference in New Issue
Block a user