Move query quantization to attention layer for Flashinfer & Triton. (#26534)
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import AttentionSpec

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)
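For context, the `ops` module selected in this import block is what gives the backend its platform-specific quantization kernels (CUDA/ROCm custom ops vs. IPEX ops on XPU). Below is a guarded variant of the same dispatch, written so it also imports cleanly where vLLM or its custom ops are unavailable; the `ops = None` fallback is an illustrative assumption, not vLLM code.

```python
# Platform-conditional selection of the quantization ops module, as in the
# hunk above, but guarded so the snippet can be imported anywhere.
try:
    from vllm.platforms import current_platform

    if current_platform.is_cuda_alike():
        from vllm import _custom_ops as ops
    elif current_platform.is_xpu():
        from vllm._ipex_ops import ipex_ops as ops
    else:
        ops = None  # assumption: no custom quant ops on this platform
except ImportError:
    ops = None  # assumption: vLLM not installed; fall back to pure-torch paths
```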
@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
    def fused_output_quant_supported(self, quant_key: QuantKey):
        return quant_key == kFp8StaticTensorSym

    def supports_quant_query_input(self) -> bool:
        return current_platform.is_cuda()

    def __init__(
        self,
        num_heads: int,
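The `supports_quant_query_input()` hook added here is how a backend advertises that it can accept an already-quantized query, so the attention layer above it can perform the quantization once instead of each backend doing it in its own forward(). The sketch below shows how a layer could use such a hook; the `SimpleAttentionLayer` class and the pure-PyTorch `quant_fp8` helper are illustrative assumptions, not vLLM's actual `Attention` layer or kernels.

```python
import torch


def quant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for a static, per-tensor FP8 quant kernel:
    # divide by the scale, clamp to the e4m3 range, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)


class SimpleAttentionLayer(torch.nn.Module):
    # Hypothetical, trimmed-down layer: `impl` stands for a backend
    # implementation (e.g. TritonAttentionImpl) that may expose
    # supports_quant_query_input().
    def __init__(self, impl, q_scale: float = 1.0):
        super().__init__()
        self.impl = impl
        self._q_scale = torch.tensor(q_scale)

    def forward(self, query, key, value, **kwargs):
        if getattr(self.impl, "supports_quant_query_input", lambda: False)():
            # Quantize once here instead of inside every backend's forward().
            query = quant_fp8(query, self._q_scale)
        return self.impl.forward(query, key, value, **kwargs)
```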
@@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
            if key_cache.dtype != self.fp8_dtype:
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )
            if current_platform.is_cuda():
                # Skip Q quantization on ROCm and XPU, enable this on cuda
                # only, since dequantizing back to f32 in the attention kernel
                # is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale,
                )
                query = query.reshape((num_tokens, num_heads, head_size))

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
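The quantization step shown in this hunk flattens the (num_tokens, num_heads, head_size) query to 2-D, runs a static per-tensor FP8 quantization with `layer._q_scale`, and reshapes back. Below is a rough pure-PyTorch approximation of that step; the `quantize_query_fp8` helper only mimics what `ops.scaled_fp8_quant` does and is not the actual vLLM op.

```python
import torch


def quantize_query_fp8(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    # query: (num_tokens, num_heads, head_size), float16/bfloat16.
    # Mirrors the step in the hunk: flatten heads, quantize with a static
    # per-tensor scale, then restore the original shape.
    num_tokens, num_heads, head_size = query.shape
    q2d = query.reshape(num_tokens, num_heads * head_size).contiguous()
    finfo = torch.finfo(torch.float8_e4m3fn)
    q_fp8 = (q2d / q_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q_fp8.reshape(num_tokens, num_heads, head_size)


# Example: with q_scale == 1.0 (the only value the backend asserts it
# supports), quantization reduces to a clamp-and-cast of the query.
q = torch.randn(4, 8, 64, dtype=torch.float16)
q_fp8 = quantize_query_fp8(q, torch.tensor(1.0))
print(q_fp8.shape, q_fp8.dtype)  # torch.Size([4, 8, 64]) torch.float8_e4m3fn
```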