diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index a250a8be0..242cc6b3b 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -177,7 +177,7 @@ Priority is **1 = highest** (tried first). | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | | `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A | | `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | -| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | +| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any | > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > diff --git a/tests/models/quantization/test_per_token_kv_cache.py b/tests/models/quantization/test_per_token_kv_cache.py new file mode 100644 index 000000000..c581f01eb --- /dev/null +++ b/tests/models/quantization/test_per_token_kv_cache.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""End-to-end accuracy tests for per-token-head KV cache quantization. + +Compares logprobs between a baseline bf16 model and the same model with +per-token-head quantized KV cache (int8 or fp8) using the Triton attention +backend. 
+ +Run: pytest tests/models/quantization/test_per_token_kv_cache.py -v -s +""" + +import pytest + +from vllm.platforms import current_platform + +from ..utils import check_logprobs_close + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="Per-token-head KV cache requires CUDA or ROCm GPU.", +) +@pytest.mark.parametrize( + "base_model,test_model", + [ + ( + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + ], +) +@pytest.mark.parametrize( + "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"] +) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("backend", ["TRITON_ATTN"]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_per_token_head_kv_cache_accuracy( + vllm_runner, + example_prompts, + base_model: str, + test_model: str, + kv_cache_dtype: str, + max_tokens: int, + enforce_eager: bool, + backend: str, + tensor_parallel_size: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Compare logprobs between bf16 baseline and per-token-head quantized KV + cache. + + Uses calculate_kv_scales (dynamic scale computation) since there are + no per-token-head calibrated checkpoints available yet. 
+ """ + with monkeypatch.context() as m: + m.setenv("TOKENIZERS_PARALLELISM", "true") + + MAX_MODEL_LEN = 1024 + NUM_LOG_PROBS = 8 + + with vllm_runner( + base_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype="auto", + attention_config={"backend": backend}, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS + ) + + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + calculate_kv_scales=True, + attention_config={"backend": backend}, + ) as vllm_model: + test_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS + ) + + check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=test_outputs, + name_0="bf16_kv_cache", + name_1=f"{kv_cache_dtype}_kv_cache", + ) diff --git a/tests/quantization/test_per_token_kv_cache.py b/tests/quantization/test_per_token_kv_cache.py new file mode 100644 index 000000000..3e660e6b0 --- /dev/null +++ b/tests/quantization/test_per_token_kv_cache.py @@ -0,0 +1,560 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for per-token-head KV cache quantization (INT8 and FP8). 
+ +Covers: +- Per-token-head Triton reshape-and-cache kernel +- Round-trip quantize/dequantize accuracy +- process_weights_after_loading early-return path +- End-to-end integration with Triton unified attention kernel + +Run: pytest tests/quantization/test_per_token_kv_cache.py -v -s +""" + +import random +from dataclasses import dataclass +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, +) +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache + +# Skip entire module if no CUDA/ROCm GPU available +pytestmark = [ + pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="Per-token-head KV cache tests require CUDA or ROCm GPU.", + ), +] + +# --------------------------------------------------------------------------- +# Test parameters +# --------------------------------------------------------------------------- +NUM_TOKENS = [1, 7, 42] +NUM_KV_HEADS = [1, 4, 8] +HEAD_SIZES = [64, 128] +BLOCK_SIZES = [16] +SEEDS = [0] + +# Platform-dependent FP8 dtype and range +FP8_DTYPE = current_platform.fp8_dtype() +FP8_MIN, FP8_MAX = get_fp8_min_max() + + +# --------------------------------------------------------------------------- +# Per-dtype quantization config +# --------------------------------------------------------------------------- +@dataclass(frozen=True) +class QuantConfig: + """Quantization parameters for a given cache dtype.""" + + cache_dtype: torch.dtype # torch.int8 or FP8_DTYPE + kv_cache_dtype_str: str # "int8_per_token_head" or "fp8_per_token_head" + quant_max: float + quant_min: float + kv_quant_mode: KVQuantMode + # INT8 Triton stores truncate; FP8 hardware casts round. 
+ uses_trunc: bool + + +INT8_CONFIG = QuantConfig( + cache_dtype=torch.int8, + kv_cache_dtype_str="int8_per_token_head", + quant_max=127.0, + quant_min=-128.0, + kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD, + uses_trunc=True, +) +FP8_CONFIG = QuantConfig( + cache_dtype=FP8_DTYPE, + kv_cache_dtype_str="fp8_per_token_head", + quant_max=FP8_MAX, + quant_min=FP8_MIN, + kv_quant_mode=KVQuantMode.FP8_PER_TOKEN_HEAD, + uses_trunc=False, +) + +QUANT_CONFIGS = [INT8_CONFIG, FP8_CONFIG] + + +@pytest.fixture(params=QUANT_CONFIGS, ids=["int8", "fp8"]) +def qcfg(request) -> QuantConfig: + return request.param + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _quantize_per_token_head_ref( + data: torch.Tensor, # [num_tokens, num_heads, head_size] + cfg: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reference per-token-head quantization (one scale per token per head). + + Returns (quantized, scales) where scales is [num_tokens, num_heads]. + """ + absmax = data.float().abs().amax(dim=2) # [num_tokens, num_heads] + scales = (absmax / cfg.quant_max).clamp(min=1e-6) + scaled = data.float() * (1.0 / scales[:, :, None]) + if cfg.uses_trunc: + q = scaled.round().clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype) + else: + q = scaled.clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype) + return q, scales + + +# =========================================================================== +# 1. 
is_quantized_kv_cache / get_kv_quant_mode +# =========================================================================== +class TestIsQuantizedKvCache: + def test_fp8_variants(self): + assert is_quantized_kv_cache("fp8") + assert is_quantized_kv_cache("fp8_e4m3") + assert is_quantized_kv_cache("fp8_e5m2") + + def test_int8_per_token_head(self): + assert is_quantized_kv_cache("int8_per_token_head") + + def test_fp8_per_token_head(self): + assert is_quantized_kv_cache("fp8_per_token_head") + + def test_auto(self): + assert not is_quantized_kv_cache("auto") + + def test_bfloat16(self): + assert not is_quantized_kv_cache("bfloat16") + + def test_kv_quant_mode_int8(self): + from vllm.v1.kv_cache_interface import get_kv_quant_mode + + assert ( + get_kv_quant_mode("int8_per_token_head") == KVQuantMode.INT8_PER_TOKEN_HEAD + ) + + def test_kv_quant_mode_fp8(self): + from vllm.v1.kv_cache_interface import get_kv_quant_mode + + assert get_kv_quant_mode("fp8_per_token_head") == KVQuantMode.FP8_PER_TOKEN_HEAD + + +# =========================================================================== +# 2. 
Triton per-token-head kernel (reshape-and-cache) +# =========================================================================== +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_KV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_reshape_and_cache_per_token_head( + qcfg: QuantConfig, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + seed: int, +): + """Test triton_reshape_and_cache_flash_per_token_head_quant kernel.""" + from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_per_token_head_quant, + ) + + set_random_seed(seed) + torch.set_default_device("cuda") + + num_blocks = (num_tokens + block_size - 1) // block_size + 4 + + key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) + value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) + + key_cache = torch.zeros( + num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype + ) + value_cache = torch.zeros( + num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype + ) + k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32) + v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32) + + num_slots = block_size * num_blocks + slot_mapping = torch.tensor( + random.sample(range(num_slots), num_tokens), dtype=torch.long + ) + + triton_reshape_and_cache_flash_per_token_head_quant( + key, + value, + key_cache, + value_cache, + k_scale_cache, + v_scale_cache, + slot_mapping, + ) + + # Reference + ref_k_quant, ref_k_scales = _quantize_per_token_head_ref(key, qcfg) + ref_v_quant, ref_v_scales = _quantize_per_token_head_ref(value, qcfg) + + # Compare dequantized values rather than raw quantized values. 
+ # Triton and PyTorch reductions can differ at FP8 rounding boundaries + # (up to 32 in quantized domain for fp8_e4m3), but the dequantized + # error is bounded by the scale. + for i, slot in enumerate(slot_mapping.tolist()): + blk = slot // block_size + off = slot % block_size + + actual_k_scale = k_scale_cache[blk, off] # [num_heads] + k_deq = key_cache[blk, off].float() * actual_k_scale[:, None] + k_ref_deq = key[i].float() + torch.testing.assert_close( + k_deq, + k_ref_deq, + atol=0.1, + rtol=0.1, + ) + actual_v_scale = v_scale_cache[blk, off] # [num_heads] + v_deq = value_cache[blk, off].float() * actual_v_scale[:, None] + v_ref_deq = value[i].float() + torch.testing.assert_close( + v_deq, + v_ref_deq, + atol=0.1, + rtol=0.1, + ) + # Per-head scales: [num_heads] + torch.testing.assert_close( + k_scale_cache[blk, off], ref_k_scales[i], atol=1e-4, rtol=1e-3 + ) + torch.testing.assert_close( + v_scale_cache[blk, off], ref_v_scales[i], atol=1e-4, rtol=1e-3 + ) + + +# =========================================================================== +# 3. Per-token-head round-trip accuracy (quantize -> dequantize) +# =========================================================================== +@pytest.mark.parametrize("num_tokens", [1, 16]) +@pytest.mark.parametrize("num_heads", [4]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("block_size", [16]) +@torch.inference_mode() +def test_per_token_head_round_trip_accuracy( + qcfg: QuantConfig, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, +): + """Verify per-token-head round-trip: kernel dequant matches reference. + + INT8: Triton truncates on float->int8 store. + FP8: hardware cast (clamp then cast). 
+ """ + from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_per_token_head_quant, + ) + + torch.set_default_device("cuda") + set_random_seed(42) + + num_blocks = (num_tokens + block_size - 1) // block_size + 2 + + key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5 + value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5 + + key_cache = torch.zeros( + num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype + ) + value_cache = torch.zeros( + num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype + ) + k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32) + v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32) + + slot_mapping = torch.arange(num_tokens, dtype=torch.long) + + triton_reshape_and_cache_flash_per_token_head_quant( + key, + value, + key_cache, + value_cache, + k_scale_cache, + v_scale_cache, + slot_mapping, + ) + + for i in range(num_tokens): + blk = i // block_size + off = i % block_size + + for label, data, cache, sc in [ + ("key", key, key_cache, k_scale_cache), + ("val", value, value_cache, v_scale_cache), + ]: + for h in range(num_heads): + orig = data[i, h].float() # [head_size] + + actual_q = cache[blk, off, h] + actual_sc = sc[blk, off, h] + actual_deq = actual_q.float() * actual_sc + + # Round-trip: dequantized should be close to original + torch.testing.assert_close( + actual_deq, + orig, + atol=0.1, + rtol=0.1, + ) + + +# =========================================================================== +# 4. 
Negative slot mapping (padding tokens should be skipped) +# =========================================================================== +@torch.inference_mode() +def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig): + """Tokens with slot_mapping=-1 should leave the cache unchanged.""" + from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_per_token_head_quant, + ) + + torch.set_default_device("cuda") + num_tokens = 4 + num_heads = 2 + head_size = 64 + block_size = 16 + num_blocks = 2 + + key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) + value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) + + key_cache = torch.zeros( + num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype + ) + value_cache = torch.zeros( + num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype + ) + k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32) + v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32) + + slot_mapping = torch.tensor([0, -1, 1, -1], dtype=torch.long) + + key_cache_before = key_cache.clone() + val_cache_before = value_cache.clone() + + triton_reshape_and_cache_flash_per_token_head_quant( + key, + value, + key_cache, + value_cache, + k_scale_cache, + v_scale_cache, + slot_mapping, + ) + + # Slots 0 and 1 should have been written (tokens 0 and 2) + assert not torch.equal(key_cache[0, 0], key_cache_before[0, 0]) + assert not torch.equal(key_cache[0, 1], key_cache_before[0, 1]) + assert not torch.equal(value_cache[0, 0], val_cache_before[0, 0]) + + # All other slots should be unchanged + assert torch.equal(key_cache[0, 2:], key_cache_before[0, 2:]) + assert torch.equal(key_cache[1], key_cache_before[1]) + assert torch.equal(value_cache[0, 2:], val_cache_before[0, 2:]) + + +# =========================================================================== +# 5. 
process_weights_after_loading -- per-token-head early return +# =========================================================================== +@pytest.mark.parametrize( + "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"] +) +def test_process_weights_sets_placeholder_scales(kv_cache_dtype: str): + """Per-token-head should set _k_scale=1.0, _v_scale=1.0 + and delete checkpoint attrs.""" + from vllm.model_executor.layers.quantization.kv_cache import ( + BaseKVCacheMethod, + ) + + layer = MagicMock() + layer.kv_cache_dtype = kv_cache_dtype + layer.calculate_kv_scales = False + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer._k_scale = torch.tensor(0.0) + layer._v_scale = torch.tensor(0.0) + layer._k_scale_float = 0.0 + layer._v_scale_float = 0.0 + + method = BaseKVCacheMethod.__new__(BaseKVCacheMethod) + method.quant_config = MagicMock() + method.process_weights_after_loading(layer) + + assert layer._k_scale_float == 1.0 + assert layer._v_scale_float == 1.0 + assert not hasattr(layer, "k_scale") + assert not hasattr(layer, "v_scale") + assert not hasattr(layer, "q_scale") + assert not hasattr(layer, "prob_scale") + + +# =========================================================================== +# 6. 
Triton unified_attention -- per-token-head scale cache (INT8 and FP8) +# =========================================================================== +@pytest.mark.parametrize( + "seq_lens", + [ + [(1, 128)], + [(1, 64), (1, 32)], + ], +) +@pytest.mark.parametrize("num_heads", [(4, 4)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("block_size", [16]) +@torch.inference_mode() +def test_triton_unified_attention_per_token_head_scale( + qcfg: QuantConfig, + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + block_size: int, +): + """End-to-end: quantized KV with per-token-head scale caches.""" + from vllm.utils.math_utils import next_power_of_2 + from vllm.v1.attention.ops.triton_unified_attention import unified_attention + + torch.set_default_device("cuda") + set_random_seed(0) + + num_seqs = len(seq_lens) + query_lens = [s[0] for s in seq_lens] + kv_lens = [s[1] for s in seq_lens] + num_query_heads, num_kv_heads = num_heads + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + num_blocks = 2048 + + query = torch.randn( + sum(query_lens), num_query_heads, head_size, dtype=torch.bfloat16 + ) + + key_cache_bf16 = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16 + ) + value_cache_bf16 = torch.randn_like(key_cache_bf16) + + # Per-token-head quantization: one scale per (block, slot, head) + k_absmax = key_cache_bf16.float().abs().amax(dim=-1) # [..., num_kv_heads] + v_absmax = value_cache_bf16.float().abs().amax(dim=-1) + k_scale_cache = (k_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32) + v_scale_cache = (v_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32) + + scaled_k = key_cache_bf16.float() / k_scale_cache[:, :, :, None] + scaled_v = value_cache_bf16.float() / v_scale_cache[:, :, :, None] + if qcfg.uses_trunc: + key_cache_q = ( + scaled_k.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype) + ) + 
value_cache_q = ( + scaled_v.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype) + ) + else: + key_cache_q = scaled_k.clamp(qcfg.quant_min, qcfg.quant_max).to( + qcfg.cache_dtype + ) + value_cache_q = scaled_v.clamp(qcfg.quant_min, qcfg.quant_max).to( + qcfg.cache_dtype + ) + + # Dequantized reference + key_cache_deq = key_cache_q.float() * k_scale_cache[:, :, :, None] + value_cache_deq = value_cache_q.float() * v_scale_cache[:, :, :, None] + + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) + kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + head_size_padded = next_power_of_2(head_size) + seq_threshold_3D = 0 + num_par_softmax_segments = 16 + softmax_segm_output = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded), + dtype=torch.float32, + ) + softmax_segm_max = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments), + dtype=torch.float32, + ) + softmax_segm_expsum = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments), + dtype=torch.float32, + ) + + output_q = torch.empty_like(query) + unified_attention( + q=query, + k=key_cache_q, + v=value_cache_q, + out=output_q, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens_t, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + block_table=block_tables, + softcap=0, + q_descale=None, + k_descale=None, + v_descale=None, + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, + kv_quant_mode=qcfg.kv_quant_mode, + k_scale_cache=k_scale_cache, + 
v_scale_cache=v_scale_cache, + ) + + output_ref = torch.empty_like(query) + unified_attention( + q=query, + k=key_cache_deq.to(torch.bfloat16), + v=value_cache_deq.to(torch.bfloat16), + out=output_ref, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens_t, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + block_table=block_tables, + softcap=0, + q_descale=None, + k_descale=None, + v_descale=None, + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, + ) + + torch.testing.assert_close(output_q, output_ref, atol=5e-2, rtol=5e-2) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 1fdce002e..cd1554590 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -8,7 +8,10 @@ from pydantic import Field, SkipValidation, field_validator, model_validator from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import ( + is_quantized_kv_cache, + kv_cache_uses_per_token_head_scales, +) logger = init_logger(__name__) @@ -21,6 +24,8 @@ CacheDType = Literal[ "fp8_e5m2", "fp8_inc", "fp8_ds_mla", + "int8_per_token_head", + "fp8_per_token_head", ] MambaDType = Literal["auto", "float32", "float16"] MambaCacheMode = Literal["all", "align", "none"] @@ -237,12 +242,20 @@ class CacheConfig: @field_validator("cache_dtype", mode="after") @classmethod def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: - if is_quantized_kv_cache(cache_dtype): + if kv_cache_uses_per_token_head_scales(cache_dtype): logger.info( - "Using fp8 data type to store kv cache. It reduces the GPU " + "Using %s data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. 
" + "Dynamic per-token-head scales will be computed at runtime.", + str(cache_dtype), + ) + elif is_quantized_kv_cache(cache_dtype): + logger.info( + "Using %s data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor." + "scaling factor", + str(cache_dtype), ) return cache_dtype diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 7610030f3..3ff4ec62a 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -38,6 +38,7 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheSpec, SlidingWindowSpec, + get_kv_quant_mode, ) if TYPE_CHECKING: @@ -381,8 +382,10 @@ class Attention(nn.Module, AttentionLayerBase): # for attn backends supporting query quantization self.query_quant = None - if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith( - "fp8" + if ( + self.impl.supports_quant_query_input + and self.kv_cache_dtype.startswith("fp8") + and not self.kv_cache_dtype.endswith("per_token_head") ): is_per_head = ( hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads @@ -539,6 +542,7 @@ class Attention(nn.Module, AttentionLayerBase): block_size = vllm_config.cache_config.block_size # Should not be called for enc-dec or encoder-only attention. 
assert self.attn_type == AttentionType.DECODER + quant_mode = get_kv_quant_mode(self.kv_cache_dtype) if self.sliding_window is not None: assert not vllm_config.model_config.use_mla, ( "MLA is not supported for slidingwindow" @@ -548,6 +552,7 @@ class Attention(nn.Module, AttentionLayerBase): num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_torch_dtype, + kv_quant_mode=quant_mode, sliding_window=self.sliding_window, ) else: @@ -557,6 +562,7 @@ class Attention(nn.Module, AttentionLayerBase): head_size=self.head_size, head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, + kv_quant_mode=quant_mode, ) diff --git a/vllm/model_executor/layers/attention/chunked_local_attention.py b/vllm/model_executor/layers/attention/chunked_local_attention.py index b747304ac..136574d97 100644 --- a/vllm/model_executor/layers/attention/chunked_local_attention.py +++ b/vllm/model_executor/layers/attention/chunked_local_attention.py @@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import ( AttentionSpec, ChunkedLocalAttentionSpec, KVCacheSpec, + get_kv_quant_mode, ) @@ -123,5 +124,6 @@ class ChunkedLocalAttention(Attention): num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_torch_dtype, + kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype), attention_chunk_size=self.attention_chunk_size, ) diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 5bd8e163f..31ac7fa1b 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -18,7 +18,11 @@ from vllm.v1.attention.backend import ( subclass_attention_backend_with_overrides, ) from vllm.v1.attention.selector import get_attn_backend -from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import ( + CrossAttentionSpec, + KVCacheSpec, + get_kv_quant_mode, +) logger = 
init_logger(__name__) @@ -220,4 +224,5 @@ class CrossAttention(Attention): num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_torch_dtype, + kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype), ) diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index 913d73a16..263d87321 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import ( AttentionSpec, KVCacheSpec, SinkFullAttentionSpec, + get_kv_quant_mode, ) logger = init_logger(__name__) @@ -217,6 +218,7 @@ class StaticSinkAttention(Attention, CustomOp): head_size_v=self.head_size_v, sink_len=self.sink_len, dtype=self.kv_cache_torch_dtype, + kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype), ) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 2fb67aacc..726ac2232 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ) from vllm.platforms import current_platform from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales logger = init_logger(__name__) @@ -53,6 +54,20 @@ class BaseKVCacheMethod(QuantizeMethodBase): assert not hasattr(layer, "prob_scale") return + # Per-token-head quantized KV cache: scales are computed dynamically + # per (token, head) in the kernel at cache-write time. Checkpoint + # scales are never used regardless of calculate_kv_scales. 
+ if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype): + layer._k_scale.copy_(1.0) + layer._v_scale.copy_(1.0) + layer._k_scale_float = 1.0 + layer._v_scale_float = 1.0 + del layer.k_scale + del layer.v_scale + del layer.q_scale + del layer.prob_scale + return + # If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. # No need to process kv scales after loading if we are going to diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 7fba7a65f..1f54004f7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -505,6 +505,7 @@ class Platform: FullAttentionSpec, MambaSpec, MLAAttentionSpec, + get_kv_quant_mode, ) cache_config = vllm_config.cache_config @@ -516,6 +517,8 @@ class Platform: else: kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype) + # Compute attention page size for 1 token if model_config.use_mla: attn_page_size_1_token = MLAAttentionSpec( @@ -523,6 +526,7 @@ class Platform: num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, + kv_quant_mode=kv_quant_mode, ).page_size_bytes else: attn_page_size_1_token = FullAttentionSpec( @@ -530,6 +534,7 @@ class Platform: num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, + kv_quant_mode=kv_quant_mode, ).page_size_bytes # Compute mamba page size diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 59c19a56e..94f8c096e 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -37,6 +37,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, "int8": torch.int8, + "int8_per_token_head": torch.int8, + "fp8_per_token_head": torch.uint8, "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, } @@ -62,7 
+64,12 @@ T = TypeVar("T") def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: - return kv_cache_dtype.startswith("fp8") + return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith("per_token_head") + + +def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: + """Return True if *kv_cache_dtype* needs per-token-head scales.""" + return kv_cache_dtype.endswith("per_token_head") def is_strictly_contiguous(t: torch.Tensor) -> bool: diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 32fac520c..4663cb71d 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -17,7 +17,9 @@ if TYPE_CHECKING: from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import KVCacheLayoutType - from vllm.v1.kv_cache_interface import AttentionSpec + from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode + +from vllm.v1.kv_cache_interface import get_kv_quant_mode class AttentionType(str, Enum): @@ -740,6 +742,13 @@ class AttentionImplBase(ABC, Generic[T]): class AttentionImpl(AttentionImplBase[T], Generic[T]): """Standard attention implementation with forward method.""" + kv_cache_dtype: str + + @property + def kv_quant_mode(self) -> "KVQuantMode": + """Return the KV cache quantization mode for this layer.""" + return get_kv_quant_mode(self.kv_cache_dtype) + @abstractmethod def __init__( self, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 3dd081745..5b1eec385 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -33,9 +33,14 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash, + 
    # Per-token-head quant: scale views carved from inline head padding.
    # Lazily initialized on first call to _ensure_scale_caches; both are
    # strided float32 views aliasing the layer's KV cache allocation.
    _k_scale_cache: torch.Tensor | None = None
    _v_scale_cache: torch.Tensor | None = None

    def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
        """Extract per-head scale views from the padded head dimension.

        The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)``
        where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last
        ``pad`` elements of each head hold one float32 scale. We create
        strided float32 views over those bytes.

        Scale shape: ``(num_blocks, block_size, num_kv_heads)``

        NOTE(review): the views are created once and cached on the
        instance; if the KV cache tensor is ever reallocated after the
        first call, these views go stale — confirm the cache is allocated
        exactly once per layer.
        NOTE(review): the integer divisions below assume
        ``hs * dtype_sz`` and the per-slot/per-block byte strides are all
        multiples of 4 (float32 alignment), and that ``kv_cache`` is
        contiguous in the layout implied by its shape — confirm both.
        """
        # Idempotent: only carve the views on the first invocation.
        if self._k_scale_cache is not None:
            return
        from vllm.utils.torch_utils import get_dtype_size

        num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape
        dtype_sz = kv_cache.element_size()
        scale_pad = get_dtype_size(torch.float32) // dtype_sz  # e.g. 4
        hs = padded_hs - scale_pad

        # Reinterpret the whole allocation as float32 so the scales can
        # be addressed with element (not byte) strides via as_strided.
        raw = kv_cache.untyped_storage()
        base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(
            raw
        )

        # In the raw bytes, each (block, kv_half, slot, head) occupies
        # padded_hs * dtype_sz bytes. The scale float32 sits at byte
        # offset hs * dtype_sz within that region.
        kv_half_bytes = block_size * nkv * padded_hs * dtype_sz
        full_block_f32 = 2 * kv_half_bytes // 4  # stride between blocks
        slot_f32 = nkv * padded_hs * dtype_sz // 4  # stride between slots
        head_f32 = padded_hs * dtype_sz // 4  # stride between heads
        scale_off_f32 = hs * dtype_sz // 4  # offset to scale within head

        # K scales: kv_half=0
        self._k_scale_cache = torch.as_strided(
            base_f32,
            size=(num_blocks, block_size, nkv),
            stride=(full_block_f32, slot_f32, head_f32),
            storage_offset=scale_off_f32,
        )
        # Initialize to 1.0 so slots that were never written through the
        # quantizing kernel dequantize as identity.
        self._k_scale_cache.fill_(1.0)

        # V scales: kv_half=1, offset by kv_half_bytes
        v_base_f32 = kv_half_bytes // 4
        self._v_scale_cache = torch.as_strided(
            base_f32,
            size=(num_blocks, block_size, nkv),
            stride=(full_block_f32, slot_f32, head_f32),
            storage_offset=v_base_f32 + scale_off_f32,
        )
        self._v_scale_cache.fill_(1.0)
TritonAttentionImpl(AttentionImpl): layer, ) - # For decoder and cross-attention, use KV cache as before - key_cache, value_cache = kv_cache.unbind(1) - if is_quantized_kv_cache(self.kv_cache_dtype): - if key_cache.dtype != self.fp8_dtype: + # Per-token-head quantized KV cache: use separate scale caches. + if self._is_per_token_head_quant: + self._ensure_scale_caches(kv_cache) + key_cache, value_cache = kv_cache.unbind(1) + if key_cache.dtype == torch.uint8: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - assert layer._q_scale_float == 1.0, ( - "A non 1.0 q_scale is not currently supported." + k_descale = None + v_descale = None + k_scale_cache = self._k_scale_cache + v_scale_cache = self._v_scale_cache + # FP8 per-tensor / auto path (original flow). + else: + key_cache, value_cache = kv_cache.unbind(1) + if is_quantized_kv_cache(self.kv_cache_dtype): + if key_cache.dtype != self.fp8_dtype: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + assert layer._q_scale_float == 1.0, ( + "A non 1.0 q_scale is not currently supported." 
+ ) + descale_shape = ( + attn_metadata.query_start_loc.shape[0] - 1, + key_cache.shape[2], ) + k_descale = layer._k_scale.expand(descale_shape) + v_descale = layer._v_scale.expand(descale_shape) + k_scale_cache = None + v_scale_cache = None cu_seqlens_q = attn_metadata.query_start_loc seqused_k = attn_metadata.seq_lens @@ -502,7 +600,6 @@ class TritonAttentionImpl(AttentionImpl): softmax_segm_max = attn_metadata.softmax_segm_max softmax_segm_expsum = attn_metadata.softmax_segm_expsum - descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2]) mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor unified_attention( @@ -522,8 +619,8 @@ class TritonAttentionImpl(AttentionImpl): block_table=block_table, softcap=self.logits_soft_cap, q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), + k_descale=k_descale, + v_descale=v_descale, seq_threshold_3D=seq_threshold_3D, num_par_softmax_segments=num_par_softmax_segments, softmax_segm_output=softmax_segm_output, @@ -532,6 +629,9 @@ class TritonAttentionImpl(AttentionImpl): sinks=self.sinks, output_scale=output_scale, mm_prefix_range=mm_prefix_range_tensor, + kv_quant_mode=self._kv_quant_mode, + k_scale_cache=k_scale_cache, + v_scale_cache=v_scale_cache, ) return output @@ -555,10 +655,10 @@ class TritonAttentionImpl(AttentionImpl): attn_metadata: Encoder attention metadata layer: The attention layer """ - # For encoder attention, process FP8 quantization if needed + # Quantized KV cache is not supported for encoder attention. 
if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "quantization is not supported for encoder attention" + "quantized KV cache is not supported for encoder attention" ) # Use encoder-specific metadata for sequence information @@ -594,16 +694,28 @@ class TritonAttentionImpl(AttentionImpl): # For encoder attention, # we use direct Q, K, V tensors without caching return - # For decoder and cross-attention, use KV cache as before - key_cache, value_cache = kv_cache.unbind(1) - # Reshape the input keys and values and store them in the cache. + if self._is_per_token_head_quant: + self._ensure_scale_caches(kv_cache) + key_cache, value_cache = kv_cache.unbind(1) + if key_cache.dtype == torch.uint8: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + triton_reshape_and_cache_flash_per_token_head_quant( + key, + value, + key_cache, + value_cache, + self._k_scale_cache, + self._v_scale_cache, + slot_mapping, + ) + return + # For decoder and cross-attention, use KV cache as before. + key_cache, value_cache = kv_cache.unbind(1) if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - # triton kernel does not support uint8 kv_cache - # (because some explicit casts (e.g. 
float8_e4m3fnuz) - # are not supported) triton_reshape_and_cache_flash( key, value, @@ -616,6 +728,8 @@ class TritonAttentionImpl(AttentionImpl): ) def fused_rope_kvcache_supported(self): + if self._is_per_token_head_quant: + return False return rocm_aiter_ops.is_enabled() def do_rope_and_kv_cache_update( diff --git a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py index eeec60962..6e696fdb5 100644 --- a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py @@ -3,10 +3,16 @@ import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + FP8_DTYPE, + get_fp8_min_max, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import is_quantized_kv_cache +FP8_MIN, FP8_MAX = get_fp8_min_max() + @triton.jit def reshape_and_cache_kernel_flash( @@ -118,6 +124,198 @@ def reshape_and_cache_kernel_flash( return +# --------------------------------------------------------------------------- +# Per-token-head dynamic quantization kernel +# Grid: (num_tokens, NUM_KV_HEADS) +# Each program handles one (token, head) pair: +# 1. Loads K (or V) for that single head +# 2. Computes absmax across head_size → scale = absmax / QUANT_MAX +# 3. Quantizes and stores the data + per-head scale +# +# Parametrised by QUANT_MAX / QUANT_MIN so the same code path works +# for int8 (±127/128), fp8_e4m3 (±448), and other formats. 
# ---------------------------------------------------------------------------
# Per-token-head dynamic quantization kernel
# Grid: (num_tokens, NUM_KV_HEADS)
# Each program handles one (token, head) pair:
#   1. Loads K (or V) for that single head
#   2. Computes absmax across head_size → scale = absmax / QUANT_MAX
#   3. Quantizes and stores the data + per-head scale
#
# Parametrised by QUANT_MAX / QUANT_MIN so the same code path works
# for int8 (±127/128), fp8_e4m3 (±448), and other formats.
# ---------------------------------------------------------------------------
@triton.jit
def _reshape_cache_per_token_head(
    key_ptr,  # [num_tokens, num_kv_heads, head_size]
    value_ptr,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache_ptr,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache_ptr,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache_ptr,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping_ptr,  # [num_tokens]
    stride_key_tok: tl.int64,
    stride_key_head: tl.int64,
    stride_val_tok: tl.int64,
    stride_val_head: tl.int64,
    stride_kc_blk: tl.int64,  # key_cache stride over blocks
    stride_kc_slot: tl.int64,  # key_cache stride over slots
    stride_kc_head: tl.int64,  # key_cache stride over heads
    stride_vc_blk: tl.int64,
    stride_vc_slot: tl.int64,
    stride_vc_head: tl.int64,
    stride_ks_blk: tl.int64,  # k_scale_cache stride[0] (blocks)
    stride_ks_slot: tl.int64,  # k_scale_cache stride[1] (slots)
    stride_ks_head: tl.int64,  # k_scale_cache stride[2] (heads)
    stride_vs_blk: tl.int64,  # v_scale_cache stride[0] (blocks)
    stride_vs_slot: tl.int64,  # v_scale_cache stride[1] (slots)
    stride_vs_head: tl.int64,  # v_scale_cache stride[2] (heads)
    block_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    HEAD_SIZE_PADDED: tl.constexpr,  # next_power_of_2(max(head_size, head_size_v))
    QUANT_MAX: tl.constexpr = 127.0,
    QUANT_MIN: tl.constexpr = -128.0,
):
    tok = tl.program_id(0)
    head = tl.program_id(1)

    slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
    # Negative slot entries mark padding tokens with no cache slot.
    if slot < 0:
        return

    blk = slot // block_size
    slot_in_blk = slot % block_size

    dim_offs = tl.arange(0, HEAD_SIZE_PADDED)

    # ---- Key: load one head → absmax → quantize → store -------------------
    k_mask = dim_offs < head_size
    k_h = tl.load(
        key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    # Floor at 1e-6 so an all-zero head cannot produce a zero scale
    # (division by the scale happens below and at dequantization).
    k_scale = tl.maximum(tl.max(tl.abs(k_h)) / QUANT_MAX, 1e-6)
    tl.store(
        k_scale_cache_ptr
        + blk * stride_ks_blk
        + slot_in_blk * stride_ks_slot
        + head * stride_ks_head,
        k_scale,
    )

    # NOTE(review): when the cache dtype is int8, this store converts
    # float32 → int8 implicitly; confirm Triton's conversion rounds
    # rather than truncates, otherwise add an explicit round before the
    # clamp to avoid a systematic quantization bias.
    k_q = tl.clamp(k_h * (1.0 / k_scale), QUANT_MIN, QUANT_MAX)
    tl.store(
        key_cache_ptr
        + blk * stride_kc_blk
        + slot_in_blk * stride_kc_slot
        + head * stride_kc_head
        + dim_offs,
        k_q,
        mask=k_mask,
    )

    # ---- Value: same per-head approach ------------------------------------
    v_mask = dim_offs < head_size_v
    v_h = tl.load(
        value_ptr + tok * stride_val_tok + head * stride_val_head + dim_offs,
        mask=v_mask,
        other=0.0,
    ).to(tl.float32)

    v_scale = tl.maximum(tl.max(tl.abs(v_h)) / QUANT_MAX, 1e-6)
    tl.store(
        v_scale_cache_ptr
        + blk * stride_vs_blk
        + slot_in_blk * stride_vs_slot
        + head * stride_vs_head,
        v_scale,
    )

    v_q = tl.clamp(v_h * (1.0 / v_scale), QUANT_MIN, QUANT_MAX)
    tl.store(
        value_cache_ptr
        + blk * stride_vc_blk
        + slot_in_blk * stride_vc_slot
        + head * stride_vc_head
        + dim_offs,
        v_q,
        mask=v_mask,
    )


# Mapping from cache torch dtype to (QUANT_MAX, QUANT_MIN) for the
# per-token-head quantization kernel.
_PER_TOKEN_HEAD_QUANT_PARAMS: dict[torch.dtype, tuple[float, float]] = {
    torch.int8: (127.0, -128.0),
    FP8_DTYPE: (FP8_MAX, FP8_MIN),
}


def triton_reshape_and_cache_flash_per_token_head_quant(
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping: torch.Tensor,  # [num_tokens]
):
    """Quantize key/value per (token, head) and write to paged cache.

    Computes one scale = absmax / QUANT_MAX per (token, head), stores
    quantized data in key_cache/value_cache, and stores the float32
    scale in k_scale_cache/v_scale_cache.

    The quantization range (QUANT_MAX, QUANT_MIN) is derived from the
    cache tensor dtype so the same code path works for int8 and fp8.

    Raises:
        ValueError: if ``key_cache.dtype`` has no entry in
            ``_PER_TOKEN_HEAD_QUANT_PARAMS``.
    """
    cache_dtype = key_cache.dtype
    quant_params = _PER_TOKEN_HEAD_QUANT_PARAMS.get(cache_dtype)
    if quant_params is None:
        raise ValueError(
            f"Per-token-head quantization not supported for cache dtype "
            f"{cache_dtype}. Supported: {list(_PER_TOKEN_HEAD_QUANT_PARAMS)}"
        )
    quant_max, quant_min = quant_params

    num_tokens, num_kv_heads, head_size = key.shape
    head_size_v = value.shape[2]
    # Single padded arange covers both K and V head sizes in the kernel.
    head_size_padded = triton.next_power_of_2(max(head_size, head_size_v))

    block_size = key_cache.shape[1]

    # NOTE(review): warp-count heuristic — fixed 4 on ROCm/XPU, scaled
    # with head size on CUDA; presumably tuned empirically, confirm.
    if current_platform.is_rocm() or current_platform.is_xpu():
        num_warps = 4
    else:
        num_warps = min(16, max(1, head_size_padded // 32))

    # One program per (token, kv_head) pair — see kernel header comment.
    _reshape_cache_per_token_head[(num_tokens, num_kv_heads)](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        slot_mapping_ptr=slot_mapping,
        stride_key_tok=key.stride(0),
        stride_key_head=key.stride(1),
        stride_val_tok=value.stride(0),
        stride_val_head=value.stride(1),
        stride_kc_blk=key_cache.stride(0),
        stride_kc_slot=key_cache.stride(1),
        stride_kc_head=key_cache.stride(2),
        stride_vc_blk=value_cache.stride(0),
        stride_vc_slot=value_cache.stride(1),
        stride_vc_head=value_cache.stride(2),
        stride_ks_blk=k_scale_cache.stride(0),
        stride_ks_slot=k_scale_cache.stride(1),
        stride_ks_head=k_scale_cache.stride(2),
        stride_vs_blk=v_scale_cache.stride(0),
        stride_vs_slot=v_scale_cache.stride(1),
        stride_vs_head=v_scale_cache.stride(2),
        block_size=block_size,
        head_size=head_size,
        head_size_v=head_size_v,
        HEAD_SIZE_PADDED=head_size_padded,
        QUANT_MAX=quant_max,
        QUANT_MIN=quant_min,
        num_warps=num_warps,
    )
*token_head_scales* is only + meaningful when ``KV_QUANT_MODE >= 2``; callers gate its use on + the same constexpr so the compiler eliminates dead code. + """ + # KV_QUANT_MODE values: 0=none, 1=fp8 per-tensor, + # 2=int8 per-token-head, 3=fp8 per-token-head + + # Placeholder scales (float32) — never read when KV_QUANT_MODE < 2. + unused_scales = tile_mask.to(tl.float32) + + if KV_QUANT_MODE == 1: # FP8 per-tensor + if Q.dtype.is_fp8(): + return data.to(Q.dtype), unused_scales + return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype), unused_scales + if KV_QUANT_MODE >= 2: # per-token-head (int8 or fp8) + scale_idx = ( + physical_block_idx * stride_s_blk + + (seq_offset % BLOCK_SIZE) * stride_s_slot + + kv_head_idx * stride_s_head + ) + token_head_scales = tl.load( + scale_cache_ptr + scale_idx, mask=tile_mask, other=1.0 + ) + return data.to(Q.dtype), token_head_scales + # .to(Q.dtype) is a no-op when data is already Q's type (bf16/fp16), + # but required so Triton sees consistent return types across branches. 
+ return data.to(Q.dtype), unused_scales + + @triton.jit def find_seq_idx( query_start_len_ptr, @@ -105,8 +163,20 @@ def kernel_unified_attention_2d( num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int USE_FP8: tl.constexpr, # bool + # KV cache quantization: 0=none, 1=fp8, 2=per-token-head + KV_QUANT_MODE: tl.constexpr = 0, FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, + # Per-token-head scale caches (KV_QUANT_MODE >= 2) + # Shape: [num_blocks, block_size, num_kv_heads] + k_scale_cache_ptr=None, + v_scale_cache_ptr=None, + stride_ks_blk=0, + stride_ks_slot=0, + stride_ks_head=0, + stride_vs_blk=0, + stride_vs_slot=0, + stride_vs_head=0, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -258,14 +328,21 @@ def kernel_unified_attention_2d( mask=dim_mask[:, None] & tile_mask[None, :], other=0.0, ) - - if K_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - K = K_load - else: - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) - else: - K = K_load + K, k_token_head_scales = _prepare_kv_tile( + K_load, + Q, + k_scale, + k_scale_cache_ptr, + physical_block_idx, + seq_offset, + kv_head_idx, + stride_ks_blk, + stride_ks_slot, + stride_ks_head, + tile_mask, + BLOCK_SIZE, + KV_QUANT_MODE, + ) # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( @@ -273,14 +350,21 @@ def kernel_unified_attention_2d( mask=dim_mask[None, :] & tile_mask[:, None], other=0.0, ) - - if V_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - V = V_load - else: - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) - else: - V = V_load + V, v_token_head_scales = _prepare_kv_tile( + V_load, + Q, + v_scale, + v_scale_cache_ptr, + physical_block_idx, + seq_offset, + kv_head_idx, + stride_vs_blk, + stride_vs_slot, + stride_vs_head, + tile_mask, + BLOCK_SIZE, + KV_QUANT_MODE, + ) # Compute attention mask: causal by default (key <= query) query_abs_pos = context_len + query_pos[:, None] @@ -318,7 +402,12 @@ def kernel_unified_attention_2d( # S : 
(BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - S += scale * tl.dot(Q, K) + # Per-token-head quant: fuse softmax_scale with per-head k_scale + # to avoid a separate BLOCK_M × TILE_SIZE multiply on S. + if KV_QUANT_MODE >= 2: + S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :]) + else: + S += scale * tl.dot(Q, K) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -382,7 +471,12 @@ def kernel_unified_attention_2d( ) # acc : (BLOCK_M, HEAD_SIZE_PADDED) - acc += tl.dot(P.to(V.dtype), V) + # Per-token-head quant: apply v_scale to P instead of V. + if KV_QUANT_MODE >= 2: + P_v = (P * v_token_head_scales[None, :]).to(V.dtype) + acc += tl.dot(P_v, V) + else: + acc += tl.dot(P.to(V.dtype), V) # epilogue acc = acc / L[:, None] @@ -453,6 +547,18 @@ def kernel_unified_attention_3d( USE_MM_PREFIX: tl.constexpr, # bool MAX_MM_RANGES: tl.constexpr, # int mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence + # KV cache quantization: 0=none, 1=fp8, 2=per-token-head + KV_QUANT_MODE: tl.constexpr = 0, + # Per-token-head scale caches (KV_QUANT_MODE >= 2) + # Shape: [num_blocks, block_size, num_kv_heads] + k_scale_cache_ptr=None, + v_scale_cache_ptr=None, + stride_ks_blk=0, + stride_ks_slot=0, + stride_ks_head=0, + stride_vs_blk=0, + stride_vs_slot=0, + stride_vs_head=0, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -613,14 +719,21 @@ def kernel_unified_attention_3d( mask=dim_mask[:, None] & tile_mask[None, :], other=0.0, ) - - if K_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - K = K_load - else: - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) - else: - K = K_load + K, k_token_head_scales = _prepare_kv_tile( + K_load, + Q, + k_scale, + k_scale_cache_ptr, + physical_block_idx, + seq_offset, + kv_head_idx, + stride_ks_blk, + stride_ks_slot, + stride_ks_head, + tile_mask, + BLOCK_SIZE, + KV_QUANT_MODE, + ) # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( @@ -628,14 +741,21 @@ def 
kernel_unified_attention_3d( mask=dim_mask[None, :] & tile_mask[:, None], other=0.0, ) - - if V_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - V = V_load - else: - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) - else: - V = V_load + V, v_token_head_scales = _prepare_kv_tile( + V_load, + Q, + v_scale, + v_scale_cache_ptr, + physical_block_idx, + seq_offset, + kv_head_idx, + stride_vs_blk, + stride_vs_slot, + stride_vs_head, + tile_mask, + BLOCK_SIZE, + KV_QUANT_MODE, + ) # Compute attention mask: causal by default (key <= query) query_abs_pos = context_len + query_pos[:, None] @@ -672,7 +792,13 @@ def kernel_unified_attention_3d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - S += scale * tl.dot(Q, K) + + # Per-token-head quant: fuse softmax_scale with per-head k_scale + # to avoid a separate BLOCK_M × TILE_SIZE multiply on S. + if KV_QUANT_MODE >= 2: + S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :]) + else: + S += scale * tl.dot(Q, K) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -736,7 +862,12 @@ def kernel_unified_attention_3d( ) # acc : (BLOCK_M, HEAD_SIZE_PADDED) - acc += tl.dot(P.to(V.dtype), V) + # Per-token-head quant: apply v_scale to P instead of V. + if KV_QUANT_MODE >= 2: + P_v = (P * v_token_head_scales[None, :]).to(V.dtype) + acc += tl.dot(P_v, V) + else: + acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( query_offset_0[:, None].to(tl.int64) @@ -911,6 +1042,10 @@ def unified_attention( # Optional tensor for prefix lengths (PrefixLM support) mm_prefix_range=None, use_alibi_sqrt=False, + # KV cache quantization mode and per-token-head scale caches. 
+ kv_quant_mode: KVQuantMode = KVQuantMode.NONE, + k_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32 + v_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32 ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -1040,6 +1175,15 @@ def unified_attention( num_seqs=num_seqs, BLOCK_M=BLOCK_M, USE_FP8=output_scale is not None, + KV_QUANT_MODE=kv_quant_mode, + k_scale_cache_ptr=k_scale_cache, + v_scale_cache_ptr=v_scale_cache, + stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0, + stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0, + stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0, + stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0, + stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0, + stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0, ) else: kernel_unified_attention_3d[ @@ -1092,6 +1236,15 @@ def unified_attention( num_seqs=num_seqs, BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, + KV_QUANT_MODE=kv_quant_mode, + k_scale_cache_ptr=k_scale_cache, + v_scale_cache_ptr=v_scale_cache, + stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0, + stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0, + stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0, + stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0, + stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0, + stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0, ) reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 48ecf6b9d..6f8ad8e7d 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,21 +1,70 @@ # 
class KVQuantMode(IntEnum):
    """KV cache quantization mode.

    Used by attention backends and kernels to dispatch quantization logic
    without string matching on ``kv_cache_dtype``.

    The integer values are part of the kernel interface: Triton kernels
    receive the mode as a ``tl.constexpr`` and branch on its numeric
    value, so existing members must never be renumbered.
    """

    NONE = 0
    FP8_PER_TENSOR = 1  # per-tensor scales (current fp8 path)
    INT8_PER_TOKEN_HEAD = 2  # per-token-head dynamic scales for int8
    FP8_PER_TOKEN_HEAD = 3  # per-token-head dynamic scales for fp8

    @property
    def is_per_token_head(self) -> bool:
        """True for any per-token-head quantization mode."""
        # Explicit membership instead of the magic comparison ``self >= 2``
        # so that adding a future non-per-token-head mode cannot silently
        # flip this property.
        return self in (
            KVQuantMode.INT8_PER_TOKEN_HEAD,
            KVQuantMode.FP8_PER_TOKEN_HEAD,
        )


def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    """Map a ``kv_cache_dtype`` string to a :class:`KVQuantMode`.

    The exact per-token-head names are tested before the generic ``fp8``
    prefix because "fp8_per_token_head" also starts with "fp8" — keep
    this ordering if new modes are added.
    """
    if kv_cache_dtype == "int8_per_token_head":
        return KVQuantMode.INT8_PER_TOKEN_HEAD
    if kv_cache_dtype == "fp8_per_token_head":
        return KVQuantMode.FP8_PER_TOKEN_HEAD
    if kv_cache_dtype.startswith("fp8"):
        return KVQuantMode.FP8_PER_TENSOR
    return KVQuantMode.NONE


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True for any quantized KV cache dtype.

    NOTE(review): a function of the same name also exists in
    vllm.utils.torch_utils — keep the two in sync (or have one
    re-export the other) to avoid divergent truth tables.
    """
    return get_kv_quant_mode(kv_cache_dtype) != KVQuantMode.NONE
get_kv_quant_mode(kv_cache_dtype).is_per_token_head + + @dataclass(frozen=True) class KVCacheSpec: """ @@ -66,11 +115,19 @@ class AttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype + kv_quant_mode: KVQuantMode = KVQuantMode.NONE page_size_padded: int | None = None @property def page_size_bytes(self) -> int: real_page_size = self.real_page_size_bytes + # Per-token-head scales are stored in separate tensors managed + # by the attention backend, but the memory is carved from the + # raw KV cache allocation so it must be budgeted here. + if self.kv_quant_mode.is_per_token_head: + real_page_size += ( + 2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32) + ) if self.page_size_padded is not None: assert self.page_size_padded >= real_page_size return self.page_size_padded @@ -159,6 +216,7 @@ class FullAttentionSpec(AttentionSpec): head_size=specs[0].head_size, head_size_v=specs[0].head_size_v, dtype=specs[0].dtype, + kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), @@ -220,6 +278,7 @@ class MLAAttentionSpec(FullAttentionSpec): num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, dtype=specs[0].dtype, + kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, cache_dtype_str=cache_dtype_str_set.pop(), ) @@ -352,6 +411,7 @@ class SinkFullAttentionSpec(FullAttentionSpec): head_size_v=specs[0].head_size_v, sink_len=specs[0].sink_len, dtype=specs[0].dtype, + kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),