[XPU] allow TORCH_SDPA/TRITON_ATTN as XPU vit Backend (#35010)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
@@ -249,7 +249,14 @@ class MMEncoderAttention(CustomOp):
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
assert self.is_flash_attn_backend, (
|
||||
"XPU only supports FLASH_ATTN for vision attention."
|
||||
)
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported multi-modal encoder attention backend for XPU: "
|
||||
f"{self.attn_backend}."
|
||||
)
|
||||
|
||||
@@ -89,6 +89,7 @@ class XPUPlatform(Platform):
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user