[XPU] support Triton Attention backend on Intel GPU (#24149)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2025-09-04 20:41:08 +08:00
committed by GitHub
parent 2b30afa442
commit 16ded21eeb
5 changed files with 49 additions and 15 deletions

View File

@@ -7,7 +7,6 @@ from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
@@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
@@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,
@@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl):
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale == 1.0, \
"A non 1.0 q_scale is not currently supported."
if not current_platform.is_rocm():
# Skip Q quantization on ROCm, since dequantizing back to
# f32 in the attention kernel is not supported.
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),