Move query quantization to attention layer for Flashinfer & Triton. (#26534)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
commit 0a9ef0cfce (parent e5b438a247)
Author: Adrian Abeyta
Date:   2025-10-15 18:01:38 -05:00 (committed by GitHub)

6 changed files with 43 additions and 38 deletions

@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
-
 logger = init_logger(__name__)
@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
 
+    def supports_quant_query_input(self) -> bool:
+        return current_platform.is_cuda()
+
     def __init__(
         self,
         num_heads: int,
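
The `supports_quant_query_input()` hook added above is what lets the shared attention layer, rather than each backend, decide whether to hand the backend an already-quantized query. A minimal caller-side sketch of that idea in plain PyTorch (the `impl`, `q_scale`, and `maybe_quantize_query` names here are illustrative, not vLLM's actual layer code):

import torch

def maybe_quantize_query(query: torch.Tensor, q_scale: torch.Tensor, impl) -> torch.Tensor:
    # Only pass an FP8 query to the backend if it advertises support;
    # e.g. TritonAttentionImpl reports True only on CUDA.
    if not impl.supports_quant_query_input():
        return query
    # Static per-tensor quantization: divide by the scale, clamp to the
    # FP8 E4M3 representable range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (query.float() / q_scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn)
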
@@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
             )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))
-
         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens
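
For reference, the removed `ops.scaled_fp8_quant(query, layer._q_scale)` call performed static per-tensor FP8 quantization of the flattened query; that work now happens in the attention layer before the backend is called. A rough pure-PyTorch stand-in for the semantics (a sketch only, not the CUDA kernel vLLM actually dispatches to):

import torch

def scaled_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor,
                         dtype: torch.dtype = torch.float8_e4m3fn):
    # Static per-tensor quantization: out = clamp(x / scale) cast to FP8,
    # returned together with the (unchanged) scale.
    finfo = torch.finfo(dtype)
    out = (x.float() / scale).clamp(finfo.min, finfo.max).to(dtype)
    return out, scale

# Shapes mirror the removed backend code: (num_tokens, num_heads, head_size).
num_tokens, num_heads, head_size = 4, 8, 64
query = torch.randn(num_tokens, num_heads, head_size)
q_scale = torch.tensor(1.0)  # the backend asserts a scale of exactly 1.0

q_fp8, _ = scaled_fp8_quant_ref(
    query.reshape(num_tokens, num_heads * head_size).contiguous(), q_scale
)
q_fp8 = q_fp8.reshape(num_tokens, num_heads, head_size)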