[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
             head_size=head_dim, dtype=torch.get_default_dtype())
         use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
+            self.attn_backend != _Backend.ROCM_AITER_FA and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
@@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
Reference in New Issue
Block a user