[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: TJian
Date: 2025-10-02 22:34:53 -07:00
Committed by: GitHub
Parent: 27edd2aeb4
Commit: 9c5ee91b2a
9 changed files with 154 additions and 141 deletions

@@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
             head_size=head_dim, dtype=torch.get_default_dtype())
         use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
+            self.attn_backend != _Backend.ROCM_AITER_FA and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
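
For readers outside the diff context, below is a minimal standalone sketch of the corrected fallback check. The function name resolve_vit_attn_backend and the local _Backend stub are illustrative stand-ins, not vLLM's actual module layout. Before this change, a ROCM_AITER_FA selection could be overridden to upstream FLASH_ATTN whenever upstream flash-attn was available; the extra condition keeps the ROCm AITER backend intact.

from enum import Enum, auto

class _Backend(Enum):
    # Local stand-in for vLLM's backend enum; only the members used here.
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    XFORMERS = auto()

def resolve_vit_attn_backend(selected: _Backend,
                             upstream_fa_available: bool) -> _Backend:
    # Fall back to upstream flash-attn only when the selected backend is
    # neither FLASH_ATTN nor ROCM_AITER_FA.
    if (selected != _Backend.FLASH_ATTN
            and selected != _Backend.ROCM_AITER_FA
            and upstream_fa_available):
        return _Backend.FLASH_ATTN
    return selected

# Example: the ROCm AITER backend is preserved even when flash-attn is importable.
assert resolve_vit_attn_backend(_Backend.ROCM_AITER_FA, True) is _Backend.ROCM_AITER_FA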
@@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
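
A similarly hedged sketch of the second hunk's effect, reusing the _Backend stub above (compute_seq_metadata is an illustrative name, not the model's actual method): flash-attn style backends, now including ROCM_AITER_FA, need the maximum sequence length derived from cu_seqlens, while xFormers needs the full per-sequence length list.

from typing import Optional
import torch

def compute_seq_metadata(
    attn_backend: _Backend,   # _Backend stub from the sketch above
    cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
    # cu_seqlens holds cumulative sequence boundaries, so adjacent
    # differences give the individual sequence lengths.
    max_seqlen, seqlens = None, None
    if attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    elif attn_backend == _Backend.XFORMERS:
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return max_seqlen, seqlens

# cu_seqlens = [0, 4, 10] describes two sequences of lengths 4 and 6.
cu = torch.tensor([0, 4, 10])
assert compute_seq_metadata(_Backend.ROCM_AITER_FA, cu) == (6, None)
assert compute_seq_metadata(_Backend.XFORMERS, cu) == (None, [4, 6])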