diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py
index 1e9c714ea..e59806abb 100644
--- a/vllm/model_executor/layers/attention/mm_encoder_attention.py
+++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -249,7 +249,14 @@ class MMEncoderAttention(CustomOp):
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
     ) -> torch.Tensor:
-        assert self.is_flash_attn_backend, (
-            "XPU only supports FLASH_ATTN for vision attention."
-        )
-        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
+        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
+            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+            return self._forward_sdpa(query, key, value, cu_seqlens)
+        else:
+            raise ValueError(
+                f"Unsupported multi-modal encoder attention backend for XPU: "
+                f"{self.attn_backend}."
+            )
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 8daa2d47f..5ce3cfba8 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -89,6 +89,7 @@ class XPUPlatform(Platform):
     def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
         return [
             AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TRITON_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
         ]
 