[Multimodal][XPU]Enable vision attn backend for xpu platform (#27525)
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Yejing Lai <yejing.lai@intel.com>
Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
             self.use_upstream_fa = True
+        if current_platform.is_xpu():
+            self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN,
             _Backend.ROCM_AITER_FA,
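The first hunk forces `use_upstream_fa` off on XPU, so the vision attention path keeps using vLLM's bundled flash-attention wrapper instead of the upstream `flash_attn` package. A minimal sketch of the kind of dispatch this flag typically gates, following the pattern used elsewhere in vLLM; the exact import site in `Qwen2_5_VisionAttention` may differ:

```python
# Hedged sketch: how use_upstream_fa commonly selects the kernel source in
# vLLM. An assumption about the surrounding code, not a quote from this file.
if self.use_upstream_fa:
    # Upstream flash-attn package (enabled e.g. on ROCm with FLASH_ATTN).
    from flash_attn import flash_attn_varlen_func
else:
    # vLLM's vendored flash-attention build (forced on XPU by this change).
    from vllm.vllm_flash_attn import flash_attn_varlen_func
```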
@@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
         seqlens = torch.zeros(1, device=cu_seqlens.device)
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
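The second hunk collapses the two-way backend check into a single set-membership test; the seqlen bookkeeping itself is unchanged. As a quick illustration of what `cu_seqlens` encodes (toy values, not from the PR):

```python
import torch

# cu_seqlens holds cumulative sequence lengths for a packed varlen batch.
cu_seqlens = torch.tensor([0, 4, 10, 13])   # three sequences: lengths 4, 6, 3

# Adjacent differences recover per-sequence lengths (the xFormers path).
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 6, 3])

# The flash-attention paths only need the longest one.
max_seqlen = seqlens.max()                  # tensor(6)
```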