[FEAT][ROCm] Enable running Flash Attention as ViT attn backend for Qwen-VL models on ROCm platform. (#22069)

Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
vllmellm
2025-08-02 14:53:18 +08:00
committed by GitHub
parent 0edaf752d7
commit d3a6f2120b
6 changed files with 64 additions and 39 deletions

View File

@@ -246,11 +246,15 @@ class Qwen2_5_VisionAttention(nn.Module):
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
@@ -297,10 +301,13 @@ class Qwen2_5_VisionAttention(nn.Module):
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.is_flash_attn_backend:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
@@ -311,7 +318,7 @@ class Qwen2_5_VisionAttention(nn.Module):
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
dropout_p=0.0,
causal=False)
context_layer = rearrange(output,
@@ -635,7 +642,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()