[XPU]fallback to TRITON_ATTN on xpu when use float32 dtype (#31762)
Signed-off-by: sihao.li <sihao.li@intel.com>
This commit is contained in:
@@ -52,11 +52,18 @@ class XPUPlatform(Platform):
                "only NHD layout is supported by XPU attention kernels."
            )

        dtype = attn_selector_config.dtype
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on XPU.")
        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif dtype == torch.float32:
            logger.warning_once(
                "Flash Attention on XPU does not support float32 dtype. "
                "Falling back to Triton Attention backend."
            )
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend.")
            return AttentionBackendEnum.FLASH_ATTN.get_path()
Reference in New Issue
Block a user