[Multimodal][XPU] Enable vision attn backend for XPU platform (#27525)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Yejing Lai <yejing.lai@intel.com>
Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Yan Ma
2025-11-01 12:45:02 +08:00
committed by GitHub
parent 3a5de7d2d6
commit 7e2729b57e
6 changed files with 88 additions and 51 deletions

View File

@@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
@@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]