[Feature] KV cache per-token-head INT8/FP8 quantization (#38378)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yangyang4991 <yangyang4991@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
JartX
2026-04-02 14:13:26 +02:00
committed by GitHub
parent 4eefbf9609
commit 2ce3d0ce36
16 changed files with 1308 additions and 66 deletions

View File

@@ -177,7 +177,7 @@ Priority is **1 = highest** (tried first).
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
>

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""End-to-end accuracy tests for per-token-head KV cache quantization.
Compares logprobs between a baseline bf16 model and the same model with
per-token-head quantized KV cache (int8 or fp8) using the Triton attention
backend.
Run: pytest tests/models/quantization/test_per_token_kv_cache.py -v -s
"""
import pytest
from vllm.platforms import current_platform
from ..utils import check_logprobs_close
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(),
    reason="Per-token-head KV cache requires CUDA or ROCm GPU.",
)
@pytest.mark.parametrize(
    "base_model,test_model",
    [
        (
            "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.2-1B-Instruct",
        ),
    ],
)
@pytest.mark.parametrize(
    "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
)
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("backend", ["TRITON_ATTN"])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_per_token_head_kv_cache_accuracy(
    vllm_runner,
    example_prompts,
    base_model: str,
    test_model: str,
    kv_cache_dtype: str,
    max_tokens: int,
    enforce_eager: bool,
    backend: str,
    tensor_parallel_size: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a per-token-head quantized KV cache keeps logprobs close.

    The same prompts are decoded greedily twice: once with
    ``kv_cache_dtype="auto"`` (baseline) and once with the parametrized
    per-token-head dtype. The quantized run passes
    ``calculate_kv_scales=True`` because no per-token-head calibrated
    checkpoints are available yet. The two logprob lists are then compared
    with ``check_logprobs_close``.
    """
    max_model_len = 1024
    num_logprobs = 8

    # Engine arguments common to the baseline and the quantized run.
    shared_kwargs = dict(
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=enforce_eager,
        attention_config={"backend": backend},
    )

    with monkeypatch.context() as m:
        m.setenv("TOKENIZERS_PARALLELISM", "true")

        with vllm_runner(
            base_model,
            kv_cache_dtype="auto",
            **shared_kwargs,
        ) as baseline_llm:
            baseline_outputs = baseline_llm.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

        with vllm_runner(
            test_model,
            kv_cache_dtype=kv_cache_dtype,
            calculate_kv_scales=True,
            **shared_kwargs,
        ) as quantized_llm:
            test_outputs = quantized_llm.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

        check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=test_outputs,
            name_0="bf16_kv_cache",
            name_1=f"{kv_cache_dtype}_kv_cache",
        )

View File

@@ -0,0 +1,560 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for per-token-head KV cache quantization (INT8 and FP8).
Covers:
- Per-token-head Triton reshape-and-cache kernel
- Round-trip quantize/dequantize accuracy
- process_weights_after_loading early-return path
- End-to-end integration with Triton unified attention kernel
Run: pytest tests/quantization/test_per_token_kv_cache.py -v -s
"""
import random
from dataclasses import dataclass
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache
# Skip entire module if no CUDA/ROCm GPU available
pytestmark = [
pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="Per-token-head KV cache tests require CUDA or ROCm GPU.",
),
]
# ---------------------------------------------------------------------------
# Test parameters
# ---------------------------------------------------------------------------
NUM_TOKENS = [1, 7, 42]
NUM_KV_HEADS = [1, 4, 8]
HEAD_SIZES = [64, 128]
BLOCK_SIZES = [16]
SEEDS = [0]
# Platform-dependent FP8 dtype and range
FP8_DTYPE = current_platform.fp8_dtype()
FP8_MIN, FP8_MAX = get_fp8_min_max()
# ---------------------------------------------------------------------------
# Per-dtype quantization config
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class QuantConfig:
    """Immutable bundle of quantization parameters for one cache dtype.

    One instance exists per tested format (INT8 and FP8); tests receive it
    through the ``qcfg`` fixture.
    """

    # Storage dtype of the quantized cache: torch.int8 or FP8_DTYPE.
    cache_dtype: torch.dtype
    # "int8_per_token_head" or "fp8_per_token_head".
    kv_cache_dtype_str: str
    # Largest / smallest representable quantized value; the reference
    # quantizers clamp to [quant_min, quant_max] before casting.
    quant_max: float
    quant_min: float
    # Mode passed to the attention kernel via ``kv_quant_mode``.
    kv_quant_mode: KVQuantMode
    # Selects the reference quantization path in this module: when True the
    # scaled values are rounded to nearest (``.round()``) before the clamp
    # and cast; when False they are clamped and cast directly.
    # NOTE(review): despite the name, the True path rounds rather than
    # truncates -- confirm which behavior the Triton INT8 store has.
    uses_trunc: bool
INT8_CONFIG = QuantConfig(
cache_dtype=torch.int8,
kv_cache_dtype_str="int8_per_token_head",
quant_max=127.0,
quant_min=-128.0,
kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
uses_trunc=True,
)
FP8_CONFIG = QuantConfig(
cache_dtype=FP8_DTYPE,
kv_cache_dtype_str="fp8_per_token_head",
quant_max=FP8_MAX,
quant_min=FP8_MIN,
kv_quant_mode=KVQuantMode.FP8_PER_TOKEN_HEAD,
uses_trunc=False,
)
QUANT_CONFIGS = [INT8_CONFIG, FP8_CONFIG]
@pytest.fixture(params=QUANT_CONFIGS, ids=["int8", "fp8"])
def qcfg(request) -> QuantConfig:
return request.param
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _quantize_per_token_head_ref(
data: torch.Tensor, # [num_tokens, num_heads, head_size]
cfg: QuantConfig,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Reference per-token-head quantization (one scale per token per head).
Returns (quantized, scales) where scales is [num_tokens, num_heads].
"""
absmax = data.float().abs().amax(dim=2) # [num_tokens, num_heads]
scales = (absmax / cfg.quant_max).clamp(min=1e-6)
scaled = data.float() * (1.0 / scales[:, :, None])
if cfg.uses_trunc:
q = scaled.round().clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
else:
q = scaled.clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
return q, scales
# ===========================================================================
# 1. is_quantized_kv_cache / get_kv_quant_mode
# ===========================================================================
class TestIsQuantizedKvCache:
    """String-level checks for KV cache dtype classification helpers."""

    def test_fp8_variants(self):
        """Every plain fp8 spelling counts as quantized."""
        for dtype_name in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            assert is_quantized_kv_cache(dtype_name)

    def test_int8_per_token_head(self):
        """The INT8 per-token-head format counts as quantized."""
        assert is_quantized_kv_cache("int8_per_token_head")

    def test_fp8_per_token_head(self):
        """The FP8 per-token-head format counts as quantized."""
        assert is_quantized_kv_cache("fp8_per_token_head")

    def test_auto(self):
        """``auto`` is not a quantized cache dtype."""
        assert not is_quantized_kv_cache("auto")

    def test_bfloat16(self):
        """``bfloat16`` is not a quantized cache dtype."""
        assert not is_quantized_kv_cache("bfloat16")

    def test_kv_quant_mode_int8(self):
        """The INT8 per-token-head string maps to its KVQuantMode member."""
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        mode = get_kv_quant_mode("int8_per_token_head")
        assert mode == KVQuantMode.INT8_PER_TOKEN_HEAD

    def test_kv_quant_mode_fp8(self):
        """The FP8 per-token-head string maps to its KVQuantMode member."""
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        mode = get_kv_quant_mode("fp8_per_token_head")
        assert mode == KVQuantMode.FP8_PER_TOKEN_HEAD
# ===========================================================================
# 2. Triton per-token-head kernel (reshape-and-cache)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_reshape_and_cache_per_token_head(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    seed: int,
):
    """Test triton_reshape_and_cache_flash_per_token_head_quant kernel.

    Random bf16 keys/values are written to randomly chosen, distinct cache
    slots. For every written slot the test checks that

    * the dequantized cache content (quantized value * stored scale) is
      close to the original input, and
    * the stored per-head scale is close to the PyTorch reference scale.
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    set_random_seed(seed)
    torch.set_default_device("cuda")

    # A few spare blocks so the random slots are spread over unused space.
    num_blocks = (num_tokens + block_size - 1) // block_size + 4

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    num_slots = block_size * num_blocks
    slot_mapping = torch.tensor(
        random.sample(range(num_slots), num_tokens), dtype=torch.long
    )

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Only the reference scales are needed. Raw quantized values are not
    # compared: Triton and PyTorch reductions can differ at FP8 rounding
    # boundaries (up to 32 in quantized domain for fp8_e4m3), whereas the
    # dequantized error is bounded by the scale.
    _, ref_k_scales = _quantize_per_token_head_ref(key, qcfg)
    _, ref_v_scales = _quantize_per_token_head_ref(value, qcfg)

    checks = (
        (key, key_cache, k_scale_cache, ref_k_scales),
        (value, value_cache, v_scale_cache, ref_v_scales),
    )
    for i, slot in enumerate(slot_mapping.tolist()):
        blk, off = divmod(slot, block_size)
        for src, cache, scale_cache, ref_scales in checks:
            actual_scale = scale_cache[blk, off]  # [num_heads]
            dequantized = cache[blk, off].float() * actual_scale[:, None]
            torch.testing.assert_close(
                dequantized,
                src[i].float(),
                atol=0.1,
                rtol=0.1,
            )
            # Per-head scales: [num_heads]
            torch.testing.assert_close(
                actual_scale, ref_scales[i], atol=1e-4, rtol=1e-3
            )
# ===========================================================================
# 3. Per-token-head round-trip accuracy (quantize -> dequantize)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", [1, 16])
@pytest.mark.parametrize("num_heads", [4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_per_token_head_round_trip_accuracy(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
):
    """Verify the quantize -> dequantize round trip through the kernel.

    Tokens are written to consecutive slots; for each slot the cached
    quantized values multiplied by the cached per-head scale must be close
    to the original bf16 input (atol = rtol = 0.1).
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    set_random_seed(42)

    num_blocks = (num_tokens + block_size - 1) // block_size + 2

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    # Token i goes to slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.long)

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    pairs = (
        (key, key_cache, k_scale_cache),
        (value, value_cache, v_scale_cache),
    )
    for i in range(num_tokens):
        blk, off = divmod(i, block_size)
        for data, cache, scale_cache in pairs:
            # All heads of the slot at once: [num_heads, head_size].
            dequantized = cache[blk, off].float() * scale_cache[blk, off][:, None]
            # Round-trip: dequantized should be close to original
            torch.testing.assert_close(
                dequantized,
                data[i].float(),
                atol=0.1,
                rtol=0.1,
            )
# ===========================================================================
# 4. Negative slot mapping (padding tokens should be skipped)
# ===========================================================================
@torch.inference_mode()
def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig):
    """Tokens with slot_mapping=-1 should leave the cache unchanged.

    Four tokens are submitted with ``slot_mapping = [0, -1, 1, -1]``. Only
    slots 0 and 1 of block 0 may be modified; every other slot of the key,
    value and scale caches must keep its initial content.
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    # Fixed seed so the random inputs are reproducible across runs.
    set_random_seed(0)

    num_tokens = 4
    num_heads = 2
    head_size = 64
    block_size = 16
    num_blocks = 2

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    slot_mapping = torch.tensor([0, -1, 1, -1], dtype=torch.long)

    key_cache_before = key_cache.clone()
    val_cache_before = value_cache.clone()
    k_scale_before = k_scale_cache.clone()
    v_scale_before = v_scale_cache.clone()

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Slots 0 and 1 should have been written (tokens 0 and 2)
    for written in (0, 1):
        assert not torch.equal(key_cache[0, written], key_cache_before[0, written])
        assert not torch.equal(value_cache[0, written], val_cache_before[0, written])
    assert not torch.equal(k_scale_cache[0, :2], k_scale_before[0, :2])
    assert not torch.equal(v_scale_cache[0, :2], v_scale_before[0, :2])

    # All other slots should be unchanged, in data and in scales.
    for after, before in (
        (key_cache, key_cache_before),
        (value_cache, val_cache_before),
        (k_scale_cache, k_scale_before),
        (v_scale_cache, v_scale_before),
    ):
        assert torch.equal(after[0, 2:], before[0, 2:])
        assert torch.equal(after[1], before[1])
# ===========================================================================
# 5. process_weights_after_loading -- per-token-head early return
# ===========================================================================
@pytest.mark.parametrize(
    "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
)
def test_process_weights_sets_placeholder_scales(kv_cache_dtype: str):
    """Per-token-head should set _k_scale=1.0, _v_scale=1.0
    and delete checkpoint attrs.

    The layer is a ``MagicMock`` carrying real tensors for the scale
    attributes, so both the in-place tensor update and the float mirrors can
    be checked after ``process_weights_after_loading`` returns.
    """
    from vllm.model_executor.layers.quantization.kv_cache import (
        BaseKVCacheMethod,
    )

    layer = MagicMock()
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = False
    # Checkpoint-style scale parameters (negative placeholder values).
    layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    # Runtime scales start at 0 so the update to 1.0 is observable.
    layer._k_scale = torch.tensor(0.0)
    layer._v_scale = torch.tensor(0.0)
    layer._k_scale_float = 0.0
    layer._v_scale_float = 0.0

    # Bypass __init__; only quant_config is provided.
    method = BaseKVCacheMethod.__new__(BaseKVCacheMethod)
    method.quant_config = MagicMock()

    method.process_weights_after_loading(layer)

    # Tensor scales and their float mirrors are both reset to 1.0.
    assert layer._k_scale.item() == 1.0
    assert layer._v_scale.item() == 1.0
    assert layer._k_scale_float == 1.0
    assert layer._v_scale_float == 1.0

    # Checkpoint attributes are removed from the layer.
    for attr in ("k_scale", "v_scale", "q_scale", "prob_scale"):
        assert not hasattr(layer, attr)
# ===========================================================================
# 6. Triton unified_attention -- per-token-head scale cache (INT8 and FP8)
# ===========================================================================
@pytest.mark.parametrize(
    "seq_lens",
    [
        [(1, 128)],
        [(1, 64), (1, 32)],
    ],
)
@pytest.mark.parametrize("num_heads", [(4, 4)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_triton_unified_attention_per_token_head_scale(
    qcfg: QuantConfig,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
):
    """End-to-end: quantized KV with per-token-head scale caches.

    A random bf16 KV cache is quantized in PyTorch with one scale per
    (block, slot, head). ``unified_attention`` is run once on the quantized
    cache together with the scale caches, and once on the dequantized cache
    cast back to bf16; the two outputs must agree within 5e-2.
    """
    from vllm.utils.math_utils import next_power_of_2
    from vllm.v1.attention.ops.triton_unified_attention import unified_attention

    torch.set_default_device("cuda")
    set_random_seed(0)

    num_seqs = len(seq_lens)
    query_lens = [q_len for q_len, _ in seq_lens]
    kv_lens = [kv_len for _, kv_len in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    num_blocks = 2048

    query = torch.randn(
        sum(query_lens), num_query_heads, head_size, dtype=torch.bfloat16
    )
    key_cache_bf16 = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16
    )
    value_cache_bf16 = torch.randn_like(key_cache_bf16)

    def _quantize(cache_bf16: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (quantized cache, per-(block, slot, head) float32 scales)."""
        absmax = cache_bf16.float().abs().amax(dim=-1)  # [..., num_kv_heads]
        scales = (absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
        scaled = cache_bf16.float() / scales[:, :, :, None]
        if qcfg.uses_trunc:
            scaled = scaled.round()
        quantized = scaled.clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
        return quantized, scales

    # Per-token-head quantization: one scale per (block, slot, head)
    key_cache_q, k_scale_cache = _quantize(key_cache_bf16)
    value_cache_q, v_scale_cache = _quantize(value_cache_bf16)

    # Dequantized reference
    key_cache_deq = key_cache_q.float() * k_scale_cache[:, :, :, None]
    value_cache_deq = value_cache_q.float() * v_scale_cache[:, :, :, None]

    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    head_size_padded = next_power_of_2(head_size)
    seq_threshold_3D = 0
    num_par_softmax_segments = 16
    softmax_segm_output = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
        dtype=torch.float32,
    )
    softmax_segm_max = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32,
    )
    softmax_segm_expsum = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32,
    )

    # Arguments identical for the quantized run and the reference run.
    common_kwargs = dict(
        q=query,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_t,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=(-1, -1),
        block_table=block_tables,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
    )

    output_q = torch.empty_like(query)
    unified_attention(
        k=key_cache_q,
        v=value_cache_q,
        out=output_q,
        kv_quant_mode=qcfg.kv_quant_mode,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
        **common_kwargs,
    )

    output_ref = torch.empty_like(query)
    unified_attention(
        k=key_cache_deq.to(torch.bfloat16),
        v=value_cache_deq.to(torch.bfloat16),
        out=output_ref,
        **common_kwargs,
    )

    torch.testing.assert_close(output_q, output_ref, atol=5e-2, rtol=5e-2)

View File

@@ -8,7 +8,10 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.utils.torch_utils import (
is_quantized_kv_cache,
kv_cache_uses_per_token_head_scales,
)
logger = init_logger(__name__)
@@ -21,6 +24,8 @@ CacheDType = Literal[
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
"int8_per_token_head",
"fp8_per_token_head",
]
MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
@@ -237,12 +242,20 @@ class CacheConfig:
@field_validator("cache_dtype", mode="after")
@classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
    """Log an informational note for quantized KV cache dtypes.

    Per-token-head dtypes get a message stating that scales are computed
    dynamically at runtime; any other quantized dtype gets the accuracy
    warning about scaling factors. The value itself is returned unchanged.
    """
    # The per-token-head check must come first: those dtype strings also
    # satisfy is_quantized_kv_cache.
    if kv_cache_uses_per_token_head_scales(cache_dtype):
        logger.info(
            "Using %s data type to store kv cache. It reduces the GPU "
            "memory footprint and boosts the performance. "
            "Dynamic per-token-head scales will be computed at runtime.",
            str(cache_dtype),
        )
    elif is_quantized_kv_cache(cache_dtype):
        logger.info(
            "Using %s data type to store kv cache. It reduces the GPU "
            "memory footprint and boosts the performance. "
            "Meanwhile, it may cause accuracy drop without a proper "
            "scaling factor.",
            str(cache_dtype),
        )
    return cache_dtype

View File

@@ -38,6 +38,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
SlidingWindowSpec,
get_kv_quant_mode,
)
if TYPE_CHECKING:
@@ -381,8 +382,10 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization
self.query_quant = None
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
"fp8"
if (
self.impl.supports_quant_query_input
and self.kv_cache_dtype.startswith("fp8")
and not self.kv_cache_dtype.endswith("per_token_head")
):
is_per_head = (
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
@@ -539,6 +542,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
quant_mode = get_kv_quant_mode(self.kv_cache_dtype)
if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow"
@@ -548,6 +552,7 @@ class Attention(nn.Module, AttentionLayerBase):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
kv_quant_mode=quant_mode,
sliding_window=self.sliding_window,
)
else:
@@ -557,6 +562,7 @@ class Attention(nn.Module, AttentionLayerBase):
head_size=self.head_size,
head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype,
kv_quant_mode=quant_mode,
)

View File

@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
KVCacheSpec,
get_kv_quant_mode,
)
@@ -123,5 +124,6 @@ class ChunkedLocalAttention(Attention):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
attention_chunk_size=self.attention_chunk_size,
)

View File

@@ -18,7 +18,11 @@ from vllm.v1.attention.backend import (
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
from vllm.v1.kv_cache_interface import (
CrossAttentionSpec,
KVCacheSpec,
get_kv_quant_mode,
)
logger = init_logger(__name__)
@@ -220,4 +224,5 @@ class CrossAttention(Attention):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
)

View File

@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
SinkFullAttentionSpec,
get_kv_quant_mode,
)
logger = init_logger(__name__)
@@ -217,6 +218,7 @@ class StaticSinkAttention(Attention, CustomOp):
head_size_v=self.head_size_v,
sink_len=self.sink_len,
dtype=self.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
)

View File

@@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
logger = init_logger(__name__)
@@ -53,6 +54,20 @@ class BaseKVCacheMethod(QuantizeMethodBase):
assert not hasattr(layer, "prob_scale")
return
# Per-token-head quantized KV cache: scales are computed dynamically
# per (token, head) in the kernel at cache-write time. Checkpoint
# scales are never used regardless of calculate_kv_scales.
if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype):
layer._k_scale.copy_(1.0)
layer._v_scale.copy_(1.0)
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
return
# If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to

View File

@@ -505,6 +505,7 @@ class Platform:
FullAttentionSpec,
MambaSpec,
MLAAttentionSpec,
get_kv_quant_mode,
)
cache_config = vllm_config.cache_config
@@ -516,6 +517,8 @@ class Platform:
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype)
# Compute attention page size for 1 token
if model_config.use_mla:
attn_page_size_1_token = MLAAttentionSpec(
@@ -523,6 +526,7 @@ class Platform:
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
kv_quant_mode=kv_quant_mode,
).page_size_bytes
else:
attn_page_size_1_token = FullAttentionSpec(
@@ -530,6 +534,7 @@ class Platform:
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
kv_quant_mode=kv_quant_mode,
).page_size_bytes
# Compute mamba page size

View File

@@ -37,6 +37,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
"int8_per_token_head": torch.int8,
"fp8_per_token_head": torch.uint8,
"fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8,
}
@@ -62,7 +64,12 @@ T = TypeVar("T")
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True if *kv_cache_dtype* names a quantized KV cache format.

    A dtype string is treated as quantized when it starts with ``"fp8"`` or
    ends with ``"per_token_head"`` (which covers ``int8_per_token_head``).
    """
    return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith(
        "per_token_head"
    )
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Tell whether *kv_cache_dtype* is a per-token-head cache format.

    The decision is purely name-based: the dtype string must end with
    ``"per_token_head"``.
    """
    per_token_head_suffix = "per_token_head"
    has_suffix = kv_cache_dtype.endswith(per_token_head_suffix)
    return has_suffix
def is_strictly_contiguous(t: torch.Tensor) -> bool:

View File

@@ -17,7 +17,9 @@ if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode
from vllm.v1.kv_cache_interface import get_kv_quant_mode
class AttentionType(str, Enum):
@@ -740,6 +742,13 @@ class AttentionImplBase(ABC, Generic[T]):
class AttentionImpl(AttentionImplBase[T], Generic[T]):
"""Standard attention implementation with forward method."""
kv_cache_dtype: str
@property
def kv_quant_mode(self) -> "KVQuantMode":
"""Return the KV cache quantization mode for this layer."""
return get_kv_quant_mode(self.kv_cache_dtype)
@abstractmethod
def __init__(
self,

View File

@@ -33,9 +33,14 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
triton_reshape_and_cache_flash_per_token_head_quant,
)
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import (
AttentionSpec,
get_kv_quant_mode,
kv_cache_uses_per_token_head_scales,
)
logger = init_logger(__name__)
@@ -270,6 +275,8 @@ class TritonAttentionBackend(AttentionBackend):
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"int8_per_token_head",
"fp8_per_token_head",
]
@staticmethod
@@ -302,6 +309,18 @@ class TritonAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
if kv_cache_uses_per_token_head_scales(cache_dtype_str):
# Pad head_size by sizeof(float32)/sizeof(cache_dtype) so
# the per-head scale fits inline. The backend extracts
# data[:head_size] and scale[head_size:] via typed views.
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_dtype_size,
)
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype_str]
scale_pad = get_dtype_size(torch.float32) // get_dtype_size(cache_dtype)
return (num_blocks, 2, block_size, num_kv_heads, head_size + scale_pad)
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
@@ -365,6 +384,62 @@ class TritonAttentionBackend(AttentionBackend):
class TritonAttentionImpl(AttentionImpl):
# Per-token-head quant: scale views carved from inline head padding.
_k_scale_cache: torch.Tensor | None = None
_v_scale_cache: torch.Tensor | None = None
def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
"""Extract per-head scale views from the padded head dimension.
The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)``
where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last
``pad`` elements of each head hold one float32 scale. We create
strided float32 views over those bytes.
Scale shape: ``(num_blocks, block_size, num_kv_heads)``
"""
if self._k_scale_cache is not None:
return
from vllm.utils.torch_utils import get_dtype_size
num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape
dtype_sz = kv_cache.element_size()
scale_pad = get_dtype_size(torch.float32) // dtype_sz # e.g. 4
hs = padded_hs - scale_pad
raw = kv_cache.untyped_storage()
base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(
raw
)
# In the raw bytes, each (block, kv_half, slot, head) occupies
# padded_hs * dtype_sz bytes. The scale float32 sits at byte
# offset hs * dtype_sz within that region.
kv_half_bytes = block_size * nkv * padded_hs * dtype_sz
full_block_f32 = 2 * kv_half_bytes // 4 # stride between blocks
slot_f32 = nkv * padded_hs * dtype_sz // 4 # stride between slots
head_f32 = padded_hs * dtype_sz // 4 # stride between heads
scale_off_f32 = hs * dtype_sz // 4 # offset to scale within head
# K scales: kv_half=0
self._k_scale_cache = torch.as_strided(
base_f32,
size=(num_blocks, block_size, nkv),
stride=(full_block_f32, slot_f32, head_f32),
storage_offset=scale_off_f32,
)
self._k_scale_cache.fill_(1.0)
# V scales: kv_half=1, offset by kv_half_bytes
v_base_f32 = kv_half_bytes // 4
self._v_scale_cache = torch.as_strided(
base_f32,
size=(num_blocks, block_size, nkv),
stride=(full_block_f32, slot_f32, head_f32),
storage_offset=v_base_f32 + scale_off_f32,
)
self._v_scale_cache.fill_(1.0)
def fused_output_quant_supported(self, quant_key: QuantKey):
    """Tell whether fused output quantization is available for *quant_key*.

    Only ``kFp8StaticTensorSym`` is accepted.
    """
    is_supported = quant_key == kFp8StaticTensorSym
    return is_supported
@@ -418,6 +493,9 @@ class TritonAttentionImpl(AttentionImpl):
self.use_alibi_sqrt = use_alibi_sqrt
self.supports_quant_query_input = current_platform.is_cuda()
self._kv_quant_mode = get_kv_quant_mode(kv_cache_dtype)
self._is_per_token_head_quant = self._kv_quant_mode.is_per_token_head
def forward(
self,
layer: torch.nn.Module,
@@ -480,15 +558,35 @@ class TritonAttentionImpl(AttentionImpl):
layer,
)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
if is_quantized_kv_cache(self.kv_cache_dtype):
if key_cache.dtype != self.fp8_dtype:
# Per-token-head quantized KV cache: use separate scale caches.
if self._is_per_token_head_quant:
self._ensure_scale_caches(kv_cache)
key_cache, value_cache = kv_cache.unbind(1)
if key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
k_descale = None
v_descale = None
k_scale_cache = self._k_scale_cache
v_scale_cache = self._v_scale_cache
# FP8 per-tensor / auto path (original flow).
else:
key_cache, value_cache = kv_cache.unbind(1)
if is_quantized_kv_cache(self.kv_cache_dtype):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
descale_shape = (
attn_metadata.query_start_loc.shape[0] - 1,
key_cache.shape[2],
)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
k_scale_cache = None
v_scale_cache = None
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
@@ -502,7 +600,6 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention(
@@ -522,8 +619,8 @@ class TritonAttentionImpl(AttentionImpl):
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
k_descale=k_descale,
v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
@@ -532,6 +629,9 @@ class TritonAttentionImpl(AttentionImpl):
sinks=self.sinks,
output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor,
kv_quant_mode=self._kv_quant_mode,
k_scale_cache=k_scale_cache,
v_scale_cache=v_scale_cache,
)
return output
@@ -555,10 +655,10 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
# Quantized KV cache is not supported for encoder attention.
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"quantization is not supported for encoder attention"
"quantized KV cache is not supported for encoder attention"
)
# Use encoder-specific metadata for sequence information
@@ -594,16 +694,28 @@ class TritonAttentionImpl(AttentionImpl):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache.
if self._is_per_token_head_quant:
self._ensure_scale_caches(kv_cache)
key_cache, value_cache = kv_cache.unbind(1)
if key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
triton_reshape_and_cache_flash_per_token_head_quant(
key,
value,
key_cache,
value_cache,
self._k_scale_cache,
self._v_scale_cache,
slot_mapping,
)
return
# For decoder and cross-attention, use KV cache as before.
key_cache, value_cache = kv_cache.unbind(1)
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
@@ -616,6 +728,8 @@ class TritonAttentionImpl(AttentionImpl):
)
def fused_rope_kvcache_supported(self):
    """Tell whether the fused RoPE + KV-cache update path may be used.

    Always ``False`` for per-token-head quantized caches; otherwise the
    answer is whatever ``rocm_aiter_ops.is_enabled()`` reports.
    """
    return False if self._is_per_token_head_quant else rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(

View File

@@ -3,10 +3,16 @@
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FP8_DTYPE,
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_quantized_kv_cache
FP8_MIN, FP8_MAX = get_fp8_min_max()
@triton.jit
def reshape_and_cache_kernel_flash(
@@ -118,6 +124,198 @@ def reshape_and_cache_kernel_flash(
return
# ---------------------------------------------------------------------------
# Per-token-head dynamic quantization kernel
# Grid: (num_tokens, NUM_KV_HEADS)
# Each program handles one (token, head) pair:
# 1. Loads K (or V) for that single head
# 2. Computes absmax across head_size → scale = absmax / QUANT_MAX
# 3. Quantizes and stores the data + per-head scale
#
# Parametrised by QUANT_MAX / QUANT_MIN so the same code path works
# for int8 (+127 / -128), fp8_e4m3 (±448), and other formats.
# ---------------------------------------------------------------------------
@triton.jit
def _reshape_cache_per_token_head(
    key_ptr,  # [num_tokens, num_kv_heads, head_size]
    value_ptr,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache_ptr,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache_ptr,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache_ptr,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping_ptr,  # [num_tokens]
    stride_key_tok: tl.int64,
    stride_key_head: tl.int64,
    stride_val_tok: tl.int64,
    stride_val_head: tl.int64,
    stride_kc_blk: tl.int64,  # key_cache stride over blocks
    stride_kc_slot: tl.int64,  # key_cache stride over slots
    stride_kc_head: tl.int64,  # key_cache stride over heads
    stride_vc_blk: tl.int64,
    stride_vc_slot: tl.int64,
    stride_vc_head: tl.int64,
    stride_ks_blk: tl.int64,  # k_scale_cache stride[0] (blocks)
    stride_ks_slot: tl.int64,  # k_scale_cache stride[1] (slots)
    stride_ks_head: tl.int64,  # k_scale_cache stride[2] (heads)
    stride_vs_blk: tl.int64,  # v_scale_cache stride[0] (blocks)
    stride_vs_slot: tl.int64,  # v_scale_cache stride[1] (slots)
    stride_vs_head: tl.int64,  # v_scale_cache stride[2] (heads)
    block_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    HEAD_SIZE_PADDED: tl.constexpr,  # next_power_of_2(max(head_size, head_size_v))
    QUANT_MAX: tl.constexpr = 127.0,
    QUANT_MIN: tl.constexpr = -128.0,
):
    # Launch grid is (token, kv head): this program owns one such pair.
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    # Tokens mapped to a negative slot are skipped entirely.
    cache_slot = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if cache_slot < 0:
        return

    # Split the flat slot id into (block, position inside the block).
    block_idx = cache_slot // block_size
    block_offset = cache_slot % block_size
    lanes = tl.arange(0, HEAD_SIZE_PADDED)

    # ---- Key ---------------------------------------------------------------
    # Lanes beyond head_size load 0.0, so they cannot affect the absmax.
    k_valid = lanes < head_size
    k_src = key_ptr + token_idx * stride_key_tok + head_idx * stride_key_head
    k_vals = tl.load(k_src + lanes, mask=k_valid, other=0.0).to(tl.float32)

    # scale = absmax / QUANT_MAX, floored at 1e-6 so the reciprocal is finite.
    k_absmax = tl.max(tl.abs(k_vals))
    k_scale = tl.maximum(k_absmax / QUANT_MAX, 1e-6)
    k_scale_dst = (
        k_scale_cache_ptr
        + block_idx * stride_ks_blk
        + block_offset * stride_ks_slot
        + head_idx * stride_ks_head
    )
    tl.store(k_scale_dst, k_scale)

    # NOTE(review): the float result is converted to the cache element type
    # by the store; for integer caches confirm whether that conversion
    # truncates or rounds.
    k_inv_scale = 1.0 / k_scale
    k_quant = tl.clamp(k_vals * k_inv_scale, QUANT_MIN, QUANT_MAX)
    k_dst = (
        key_cache_ptr
        + block_idx * stride_kc_blk
        + block_offset * stride_kc_slot
        + head_idx * stride_kc_head
    )
    tl.store(k_dst + lanes, k_quant, mask=k_valid)

    # ---- Value: identical procedure with head_size_v -----------------------
    v_valid = lanes < head_size_v
    v_src = value_ptr + token_idx * stride_val_tok + head_idx * stride_val_head
    v_vals = tl.load(v_src + lanes, mask=v_valid, other=0.0).to(tl.float32)

    v_absmax = tl.max(tl.abs(v_vals))
    v_scale = tl.maximum(v_absmax / QUANT_MAX, 1e-6)
    v_scale_dst = (
        v_scale_cache_ptr
        + block_idx * stride_vs_blk
        + block_offset * stride_vs_slot
        + head_idx * stride_vs_head
    )
    tl.store(v_scale_dst, v_scale)

    v_inv_scale = 1.0 / v_scale
    v_quant = tl.clamp(v_vals * v_inv_scale, QUANT_MIN, QUANT_MAX)
    v_dst = (
        value_cache_ptr
        + block_idx * stride_vc_blk
        + block_offset * stride_vc_slot
        + head_idx * stride_vc_head
    )
    tl.store(v_dst + lanes, v_quant, mask=v_valid)
# Mapping from cache torch dtype to (QUANT_MAX, QUANT_MIN) for the
# per-token-head quantization kernel.
_PER_TOKEN_HEAD_QUANT_PARAMS: dict[torch.dtype, tuple[float, float]] = {
torch.int8: (127.0, -128.0),
FP8_DTYPE: (FP8_MAX, FP8_MIN),
}
def triton_reshape_and_cache_flash_per_token_head_quant(
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping: torch.Tensor,  # [num_tokens]
):
    """Quantize key/value per (token, head) and write to paged cache.

    Launches ``_reshape_cache_per_token_head`` with one program per
    (token, kv head).  Each program computes ``scale = absmax / QUANT_MAX``,
    stores the quantized data in ``key_cache`` / ``value_cache`` and the
    float32 scale in ``k_scale_cache`` / ``v_scale_cache``.

    The quantization range (QUANT_MAX, QUANT_MIN) is looked up from the
    dtype of ``key_cache`` so the same code path works for int8 and fp8.

    The number of tokens processed is the smallest leading dimension of
    ``key``, ``value`` and ``slot_mapping``: the kernel indexes all three by
    token id, so sizing the grid from ``key`` alone would read out of bounds
    whenever ``key`` carries more rows than ``slot_mapping``.

    Raises:
        ValueError: if the dtype of ``key_cache`` has no registered
            quantization range.
    """
    cache_dtype = key_cache.dtype
    quant_params = _PER_TOKEN_HEAD_QUANT_PARAMS.get(cache_dtype)
    if quant_params is None:
        raise ValueError(
            f"Per-token-head quantization not supported for cache dtype "
            f"{cache_dtype}. Supported: {list(_PER_TOKEN_HEAD_QUANT_PARAMS)}"
        )
    quant_max, quant_min = quant_params

    _, num_kv_heads, head_size = key.shape
    head_size_v = value.shape[2]
    # Never launch a program for a token id that one of the inputs lacks.
    num_tokens = min(key.shape[0], value.shape[0], slot_mapping.shape[0])
    if num_tokens == 0:
        return

    # A single lane range covers both K and V, hence the max().
    head_size_padded = triton.next_power_of_2(max(head_size, head_size_v))
    block_size = key_cache.shape[1]

    if current_platform.is_rocm() or current_platform.is_xpu():
        num_warps = 4
    else:
        num_warps = min(16, max(1, head_size_padded // 32))

    _reshape_cache_per_token_head[(num_tokens, num_kv_heads)](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        slot_mapping_ptr=slot_mapping,
        stride_key_tok=key.stride(0),
        stride_key_head=key.stride(1),
        stride_val_tok=value.stride(0),
        stride_val_head=value.stride(1),
        stride_kc_blk=key_cache.stride(0),
        stride_kc_slot=key_cache.stride(1),
        stride_kc_head=key_cache.stride(2),
        stride_vc_blk=value_cache.stride(0),
        stride_vc_slot=value_cache.stride(1),
        stride_vc_head=value_cache.stride(2),
        stride_ks_blk=k_scale_cache.stride(0),
        stride_ks_slot=k_scale_cache.stride(1),
        stride_ks_head=k_scale_cache.stride(2),
        stride_vs_blk=v_scale_cache.stride(0),
        stride_vs_slot=v_scale_cache.stride(1),
        stride_vs_head=v_scale_cache.stride(2),
        block_size=block_size,
        head_size=head_size,
        head_size_v=head_size_v,
        HEAD_SIZE_PADDED=head_size_padded,
        QUANT_MAX=quant_max,
        QUANT_MIN=quant_min,
        num_warps=num_warps,
    )
def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size]
@@ -224,7 +422,6 @@ def triton_reshape_and_cache_flash(
block_size=block_size,
x=x,
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
TILE_SIZE=TILE_SIZE,

View File

@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVQuantMode
logger = init_logger(__name__)
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
@@ -32,6 +33,63 @@ def apply_softcap(S, x):
return x * (p1 - p2) / (p1 + p2)
@triton.jit
def _prepare_kv_tile(
    data,
    Q,
    tensor_scale,
    scale_cache_ptr,
    physical_block_idx,
    seq_offset,
    kv_head_idx,
    stride_s_blk,
    stride_s_slot,
    stride_s_head,
    tile_mask,
    BLOCK_SIZE: tl.constexpr,
    KV_QUANT_MODE: tl.constexpr,
):
    """Convert a freshly loaded K or V tile into the form the dot product uses.

    Behaviour is selected at compile time by ``KV_QUANT_MODE``
    (0=none, 1=fp8 per-tensor, 2=int8 per-token-head,
    3=fp8 per-token-head):

    - ``0``: the tile is only cast to Q's dtype.
    - ``1``: if Q is fp8 the tile is cast as-is; otherwise it is multiplied
      by the scalar loaded from ``tensor_scale`` in float32 and then cast
      to Q's dtype.
    - ``>= 2``: the tile is cast to Q's dtype without scaling, and the
      per-(token, head) scales for the tile are loaded from
      ``scale_cache_ptr`` and returned for the caller to apply.

    Returns ``(tile, scales)``.  For modes below 2 *scales* is a float32
    placeholder derived from ``tile_mask`` that is not meant to be used.
    """
    # Float32 placeholder so every branch returns a value of the same type.
    placeholder_scales = tile_mask.to(tl.float32)

    if KV_QUANT_MODE == 1:  # FP8 with one scale for the whole tensor
        if Q.dtype.is_fp8():
            return data.to(Q.dtype), placeholder_scales
        data_f32 = data.to(tl.float32)
        dequantized = data_f32 * tl.load(tensor_scale)
        return dequantized.to(Q.dtype), placeholder_scales

    if KV_QUANT_MODE >= 2:  # one scale per (token, head), int8 or fp8
        slot_in_block = seq_offset % BLOCK_SIZE
        scale_offsets = (
            physical_block_idx * stride_s_blk
            + slot_in_block * stride_s_slot
            + kv_head_idx * stride_s_head
        )
        # Masked-out positions read a neutral scale of 1.0.
        token_head_scales = tl.load(
            scale_cache_ptr + scale_offsets, mask=tile_mask, other=1.0
        )
        return data.to(Q.dtype), token_head_scales

    # Unquantized cache: the cast keeps the return type consistent with
    # the branches above.
    return data.to(Q.dtype), placeholder_scales
@triton.jit
def find_seq_idx(
query_start_len_ptr,
@@ -105,8 +163,20 @@ def kernel_unified_attention_2d(
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
# KV cache quantization: 0=none, 1=fp8 per-tensor,
# 2=int8 per-token-head, 3=fp8 per-token-head
KV_QUANT_MODE: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
# Per-token-head scale caches (KV_QUANT_MODE >= 2)
# Shape: [num_blocks, block_size, num_kv_heads]
k_scale_cache_ptr=None,
v_scale_cache_ptr=None,
stride_ks_blk=0,
stride_ks_slot=0,
stride_ks_head=0,
stride_vs_blk=0,
stride_vs_slot=0,
stride_vs_head=0,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -258,14 +328,21 @@ def kernel_unified_attention_2d(
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
)
if K_load.dtype.is_fp8():
if Q.dtype.is_fp8():
K = K_load
else:
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
K, k_token_head_scales = _prepare_kv_tile(
K_load,
Q,
k_scale,
k_scale_cache_ptr,
physical_block_idx,
seq_offset,
kv_head_idx,
stride_ks_blk,
stride_ks_slot,
stride_ks_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
@@ -273,14 +350,21 @@ def kernel_unified_attention_2d(
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
)
if V_load.dtype.is_fp8():
if Q.dtype.is_fp8():
V = V_load
else:
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
V, v_token_head_scales = _prepare_kv_tile(
V_load,
Q,
v_scale,
v_scale_cache_ptr,
physical_block_idx,
seq_offset,
kv_head_idx,
stride_vs_blk,
stride_vs_slot,
stride_vs_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
@@ -318,7 +402,12 @@ def kernel_unified_attention_2d(
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K)
# Per-token-head quant: fuse softmax_scale with per-head k_scale
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
if KV_QUANT_MODE >= 2:
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
else:
S += scale * tl.dot(Q, K)
if USE_SOFTCAP:
S = apply_softcap(S, softcap)
@@ -382,7 +471,12 @@ def kernel_unified_attention_2d(
)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
# Per-token-head quant: apply v_scale to P instead of V.
if KV_QUANT_MODE >= 2:
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
acc += tl.dot(P_v, V)
else:
acc += tl.dot(P.to(V.dtype), V)
# epilogue
acc = acc / L[:, None]
@@ -453,6 +547,18 @@ def kernel_unified_attention_3d(
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
# KV cache quantization: 0=none, 1=fp8 per-tensor,
# 2=int8 per-token-head, 3=fp8 per-token-head
KV_QUANT_MODE: tl.constexpr = 0,
# Per-token-head scale caches (KV_QUANT_MODE >= 2)
# Shape: [num_blocks, block_size, num_kv_heads]
k_scale_cache_ptr=None,
v_scale_cache_ptr=None,
stride_ks_blk=0,
stride_ks_slot=0,
stride_ks_head=0,
stride_vs_blk=0,
stride_vs_slot=0,
stride_vs_head=0,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -613,14 +719,21 @@ def kernel_unified_attention_3d(
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
)
if K_load.dtype.is_fp8():
if Q.dtype.is_fp8():
K = K_load
else:
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
K, k_token_head_scales = _prepare_kv_tile(
K_load,
Q,
k_scale,
k_scale_cache_ptr,
physical_block_idx,
seq_offset,
kv_head_idx,
stride_ks_blk,
stride_ks_slot,
stride_ks_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
@@ -628,14 +741,21 @@ def kernel_unified_attention_3d(
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
)
if V_load.dtype.is_fp8():
if Q.dtype.is_fp8():
V = V_load
else:
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
V, v_token_head_scales = _prepare_kv_tile(
V_load,
Q,
v_scale,
v_scale_cache_ptr,
physical_block_idx,
seq_offset,
kv_head_idx,
stride_vs_blk,
stride_vs_slot,
stride_vs_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
@@ -672,7 +792,13 @@ def kernel_unified_attention_3d(
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K)
# Per-token-head quant: fuse softmax_scale with per-head k_scale
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
if KV_QUANT_MODE >= 2:
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
else:
S += scale * tl.dot(Q, K)
if USE_SOFTCAP:
S = apply_softcap(S, softcap)
@@ -736,7 +862,12 @@ def kernel_unified_attention_3d(
)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
# Per-token-head quant: apply v_scale to P instead of V.
if KV_QUANT_MODE >= 2:
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
acc += tl.dot(P_v, V)
else:
acc += tl.dot(P.to(V.dtype), V)
segm_output_offset = (
query_offset_0[:, None].to(tl.int64)
@@ -911,6 +1042,10 @@ def unified_attention(
# Optional tensor for prefix lengths (PrefixLM support)
mm_prefix_range=None,
use_alibi_sqrt=False,
# KV cache quantization mode and per-token-head scale caches.
kv_quant_mode: KVQuantMode = KVQuantMode.NONE,
k_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32
v_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@@ -1040,6 +1175,15 @@ def unified_attention(
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
KV_QUANT_MODE=kv_quant_mode,
k_scale_cache_ptr=k_scale_cache,
v_scale_cache_ptr=v_scale_cache,
stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0,
stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0,
stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0,
stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0,
stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0,
stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0,
)
else:
kernel_unified_attention_3d[
@@ -1092,6 +1236,15 @@ def unified_attention(
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
KV_QUANT_MODE=kv_quant_mode,
k_scale_cache_ptr=k_scale_cache,
v_scale_cache_ptr=v_scale_cache,
stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0,
stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0,
stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0,
stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0,
stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0,
stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0,
)
reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out,

View File

@@ -1,21 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import copy
from dataclasses import dataclass, fields, replace
from enum import IntEnum
from math import prod
from typing import TYPE_CHECKING
import torch
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# KV cache quantization mode
# ---------------------------------------------------------------------------
class KVQuantMode(IntEnum):
    """How (and whether) the KV cache is quantized.

    Lets attention backends and kernels branch on an integer mode instead
    of string-matching ``kv_cache_dtype``.
    """

    NONE = 0
    # Per-tensor scales (current fp8 path).
    FP8_PER_TENSOR = 1
    # Per-token-head dynamic scales for int8.
    INT8_PER_TOKEN_HEAD = 2
    # Per-token-head dynamic scales for fp8.
    FP8_PER_TOKEN_HEAD = 3

    @property
    def is_per_token_head(self) -> bool:
        """Whether this mode is one of the per-token-head modes."""
        # The per-token-head members are the ones numbered from
        # INT8_PER_TOKEN_HEAD upwards.
        return int(self) >= int(KVQuantMode.INT8_PER_TOKEN_HEAD)
def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    """Translate a ``kv_cache_dtype`` string into its :class:`KVQuantMode`.

    The two per-token-head names are matched exactly and checked first;
    any other string starting with ``"fp8"`` is per-tensor fp8; everything
    else is unquantized.
    """
    if kv_cache_dtype == "int8_per_token_head":
        mode = KVQuantMode.INT8_PER_TOKEN_HEAD
    elif kv_cache_dtype == "fp8_per_token_head":
        # Must precede the generic "fp8" prefix test below.
        mode = KVQuantMode.FP8_PER_TOKEN_HEAD
    elif kv_cache_dtype.startswith("fp8"):
        mode = KVQuantMode.FP8_PER_TENSOR
    else:
        mode = KVQuantMode.NONE
    return mode
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True when *kv_cache_dtype* maps to any mode other than NONE."""
    mode = get_kv_quant_mode(kv_cache_dtype)
    return mode != KVQuantMode.NONE
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Tell whether *kv_cache_dtype* selects a per-token-head quant mode."""
    mode = get_kv_quant_mode(kv_cache_dtype)
    return mode.is_per_token_head
@dataclass(frozen=True)
class KVCacheSpec:
"""
@@ -66,11 +115,19 @@ class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
kv_quant_mode: KVQuantMode = KVQuantMode.NONE
page_size_padded: int | None = None
@property
def page_size_bytes(self) -> int:
real_page_size = self.real_page_size_bytes
# Per-token-head scales are stored in separate tensors managed
# by the attention backend, but the memory is carved from the
# raw KV cache allocation so it must be budgeted here.
if self.kv_quant_mode.is_per_token_head:
real_page_size += (
2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
)
if self.page_size_padded is not None:
assert self.page_size_padded >= real_page_size
return self.page_size_padded
@@ -159,6 +216,7 @@ class FullAttentionSpec(AttentionSpec):
head_size=specs[0].head_size,
head_size_v=specs[0].head_size_v,
dtype=specs[0].dtype,
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
@@ -220,6 +278,7 @@ class MLAAttentionSpec(FullAttentionSpec):
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded,
cache_dtype_str=cache_dtype_str_set.pop(),
)
@@ -352,6 +411,7 @@ class SinkFullAttentionSpec(FullAttentionSpec):
head_size_v=specs[0].head_size_v,
sink_len=specs[0].sink_len,
dtype=specs[0].dtype,
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),