[XPU] support Triton Attention backend on Intel GPU (#24149)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -7,7 +7,6 @@ from typing import ClassVar, Optional
 import torch
 
-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+if current_platform.is_cuda_alike():
+    from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+
 logger = init_logger(__name__)
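The module now binds ops to the CUDA/ROCm custom ops or to the IPEX ops at import time, so every later ops.* call site stays backend-agnostic. A self-contained sketch of the same dispatch pattern (illustrative only; the helper classes and the torch.xpu probe are assumptions, not vLLM code):

```python
# Sketch: choose a backend "ops" implementation once at import time so the
# rest of the module can call ops.<kernel> without per-call branching.
import torch

class _CudaLikeOps:
    @staticmethod
    def scale(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # stand-in for a CUDA/ROCm kernel

class _XpuOps:
    @staticmethod
    def scale(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # stand-in for an IPEX/XPU kernel

if torch.cuda.is_available():
    ops = _CudaLikeOps
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    ops = _XpuOps
else:
    ops = _CudaLikeOps  # CPU fallback so the sketch still runs

print(ops.scale(torch.ones(2)))  # identical call site on every platform
```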
@@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl):
                 layer._v_scale,
             )
         else:
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
+            ops.reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
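This swaps the direct torch.ops._C_cache_ops call for the platform-dispatched ops alias, so XPU picks up the IPEX implementation. For intuition, a minimal pure-PyTorch sketch (an assumption, not vLLM's kernel) of what a reshape_and_cache_flash-style op does: scatter each token's new key/value into the paged KV cache at the slot given by slot_mapping.

```python
# Sketch: write new per-token keys/values into a block-paged KV cache.
import torch

num_tokens, num_heads, head_size = 4, 2, 8
num_blocks, block_size = 3, 16

key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
slot_mapping = torch.tensor([0, 1, 17, 18])  # flat slot index per token

block_idx = slot_mapping // block_size   # which cache block each token lands in
block_off = slot_mapping % block_size    # offset inside that block
key_cache[block_idx, block_off] = key
value_cache[block_idx, block_off] = value
```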
@@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl):
         num_tokens, num_heads, head_size = query.shape
         assert layer._q_scale == 1.0, \
             "A non 1.0 q_scale is not currently supported."
-        if not current_platform.is_rocm():
-            # Skip Q quantization on ROCm, since dequantizing back to
-            # f32 in the attention kernel is not supported.
+        if current_platform.is_cuda():
+            # Skip Q quantization on ROCm and XPU, enable this on cuda
+            # only, since dequantizing back to f32 in the attention kernel
+            # is not supported.
             query, _ = ops.scaled_fp8_quant(
                 query.reshape(
                     (num_tokens, num_heads * head_size)).contiguous(),
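With this change the FP8 quantization of the query runs only on CUDA; ROCm and XPU leave the query unquantized because dequantizing back to f32 inside the attention kernel is not supported there. A rough sketch of that gating, using a naive per-tensor FP8 cast as a hypothetical stand-in for ops.scaled_fp8_quant (assumes torch >= 2.1 for float8_e4m3fn):

```python
# Sketch of the CUDA-only gating around query quantization.
import torch

def maybe_quantize_query(query: torch.Tensor, is_cuda: bool):
    """Quantize the query to FP8 on CUDA only; other platforms keep it as-is."""
    if not is_cuda:
        return query, None
    # Naive per-tensor scale so values fit the FP8 e4m3 range.
    scale = query.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
    q_fp8 = (query / scale).to(torch.float8_e4m3fn)
    return q_fp8, scale

q = torch.randn(4, 8)
print(maybe_quantize_query(q, is_cuda=False)[0].dtype)  # torch.float32, untouched
```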