[XPU] Fall back to TRITON_ATTN on XPU when using float32 dtype (#31762)

Signed-off-by: sihao.li <sihao.li@intel.com>
This commit is contained in:
sihao_li
2026-01-07 16:10:29 +08:00
committed by GitHub
parent e7596371a4
commit 59fe6f298e

View File

@@ -52,11 +52,18 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
dtype = attn_selector_config.dtype
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
elif dtype == torch.float32:
logger.warning_once(
"Flash Attention on XPU does not support float32 dtype. "
"Falling back to Triton Attention backend."
)
return AttentionBackendEnum.TRITON_ATTN.get_path()
elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path()