[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2025-10-02 22:34:53 -07:00
committed by GitHub
parent 27edd2aeb4
commit 9c5ee91b2a
9 changed files with 154 additions and 141 deletions

View File

@@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import (
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
@@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module):
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now.")
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
output = self.flash_attn_varlen_func(
q,
k,
v,
@@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module):
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens