[XPU] allow TORCH_SDPA/TRITON_ATTN as XPU vit Backend (#35010)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-02-23 21:06:44 +08:00
committed by GitHub
parent 5f68464f92
commit b1b5e045df
2 changed files with 12 additions and 4 deletions

View File

@@ -249,7 +249,14 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for XPU: "
f"{self.attn_backend}."
)

View File

@@ -89,6 +89,7 @@ class XPUPlatform(Platform):
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TORCH_SDPA,
]