[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
TJian
2025-10-02 22:34:53 -07:00
committed by GitHub
parent 27edd2aeb4
commit 9c5ee91b2a
9 changed files with 154 additions and 141 deletions

vllm/model_executor/models/qwen2_vl.py

@@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
@@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module):
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module):
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -383,25 +386,18 @@ class Qwen2VisionAttention(nn.Module):
             q, k = torch.chunk(qk_rotated, 2, dim=0)
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-            output = flash_attn_varlen_func(q,
-                                            k,
-                                            v,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q,
+                                                 k,
+                                                 v,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)
         context_layer = rearrange(output,
                                   "(b s) h d -> s b (h d)",