# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for per-token-head KV cache quantization (INT8 and FP8).

Covers:
- Per-token-head Triton reshape-and-cache kernel
- Round-trip quantize/dequantize accuracy
- process_weights_after_loading early-return path
- End-to-end integration with Triton unified attention kernel

Run: pytest tests/quantization/test_per_token_kv_cache.py -v -s
"""

import random
from dataclasses import dataclass
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache

# Skip entire module if no CUDA/ROCm GPU available
pytestmark = [
    pytest.mark.skipif(
        not current_platform.is_cuda_alike(),
        reason="Per-token-head KV cache tests require CUDA or ROCm GPU.",
    ),
]

# ---------------------------------------------------------------------------
# Test parameters
# ---------------------------------------------------------------------------
NUM_TOKENS = [1, 7, 42]
NUM_KV_HEADS = [1, 4, 8]
HEAD_SIZES = [64, 128]
BLOCK_SIZES = [16]
SEEDS = [0]

# Platform-dependent FP8 dtype and range
FP8_DTYPE = current_platform.fp8_dtype()
FP8_MIN, FP8_MAX = get_fp8_min_max()


# ---------------------------------------------------------------------------
# Per-dtype quantization config
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class QuantConfig:
    """Quantization parameters for one KV-cache dtype under test."""

    cache_dtype: torch.dtype  # torch.int8 or FP8_DTYPE
    kv_cache_dtype_str: str  # "int8_per_token_head" or "fp8_per_token_head"
    quant_max: float
    quant_min: float
    kv_quant_mode: KVQuantMode
    # True -> reference applies explicit round+clamp before the int8 cast;
    # False -> clamp only, relying on the FP8 hardware cast to round.
    uses_trunc: bool


INT8_CONFIG = QuantConfig(
    cache_dtype=torch.int8,
    kv_cache_dtype_str="int8_per_token_head",
    quant_max=127.0,
    quant_min=-128.0,
    kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
    uses_trunc=True,
)

FP8_CONFIG = QuantConfig(
    cache_dtype=FP8_DTYPE,
    kv_cache_dtype_str="fp8_per_token_head",
    quant_max=FP8_MAX,
    quant_min=FP8_MIN,
    kv_quant_mode=KVQuantMode.FP8_PER_TOKEN_HEAD,
    uses_trunc=False,
)

QUANT_CONFIGS = [INT8_CONFIG, FP8_CONFIG]


@pytest.fixture(params=QUANT_CONFIGS, ids=["int8", "fp8"])
def qcfg(request) -> QuantConfig:
    """Parametrized fixture yielding the INT8 and FP8 quantization configs."""
    return request.param


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _quantize_per_token_head_ref(
    data: torch.Tensor,  # [num_tokens, num_heads, head_size]
    cfg: QuantConfig,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference per-token-head quantization (one scale per token per head).

    Returns (quantized, scales) where scales is [num_tokens, num_heads].
    """
    data_f32 = data.float()
    per_head_absmax = data_f32.abs().amax(dim=2)  # [num_tokens, num_heads]
    scales = (per_head_absmax / cfg.quant_max).clamp(min=1e-6)
    # Multiply by the reciprocal (not divide) to mirror the kernel's
    # arithmetic exactly; dividing would give slightly different floats.
    scaled = data_f32 * (1.0 / scales[:, :, None])
    if cfg.uses_trunc:
        scaled = scaled.round()
    quantized = scaled.clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
    return quantized, scales


# ===========================================================================
# 1. is_quantized_kv_cache / get_kv_quant_mode
# ===========================================================================
class TestIsQuantizedKvCache:
    """Unit tests for the KV-cache dtype-string predicate and mode lookup."""

    def test_fp8_variants(self):
        # All fp8 spellings must be recognized as quantized.
        for dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            assert is_quantized_kv_cache(dtype_str)

    def test_int8_per_token_head(self):
        assert is_quantized_kv_cache("int8_per_token_head")

    def test_fp8_per_token_head(self):
        assert is_quantized_kv_cache("fp8_per_token_head")

    def test_auto(self):
        assert not is_quantized_kv_cache("auto")

    def test_bfloat16(self):
        assert not is_quantized_kv_cache("bfloat16")

    def test_kv_quant_mode_int8(self):
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        mode = get_kv_quant_mode("int8_per_token_head")
        assert mode == KVQuantMode.INT8_PER_TOKEN_HEAD

    def test_kv_quant_mode_fp8(self):
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        mode = get_kv_quant_mode("fp8_per_token_head")
        assert mode == KVQuantMode.FP8_PER_TOKEN_HEAD


# ===========================================================================
# 2. Triton per-token-head kernel (reshape-and-cache)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_reshape_and_cache_per_token_head(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    seed: int,
):
    """Exercise triton_reshape_and_cache_flash_per_token_head_quant.

    The kernel writes quantized K/V plus one scale per (token, head) into
    paged caches; both the dequantized values and the scales are checked
    against a pure-PyTorch reference.
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    set_random_seed(seed)
    torch.set_default_device("cuda")

    # A few spare blocks so random slot sampling has headroom.
    num_blocks = (num_tokens + block_size - 1) // block_size + 4

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    cache_shape = (num_blocks, block_size, num_heads, head_size)
    key_cache = torch.zeros(cache_shape, dtype=qcfg.cache_dtype)
    value_cache = torch.zeros(cache_shape, dtype=qcfg.cache_dtype)
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    # Random, collision-free slot assignment across the whole cache.
    num_slots = block_size * num_blocks
    slot_mapping = torch.tensor(
        random.sample(range(num_slots), num_tokens), dtype=torch.long
    )

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Reference scales; raw quantized values are deliberately not compared.
    _, ref_k_scales = _quantize_per_token_head_ref(key, qcfg)
    _, ref_v_scales = _quantize_per_token_head_ref(value, qcfg)

    # Compare dequantized values rather than raw quantized values.
    # Triton and PyTorch reductions can differ at FP8 rounding boundaries
    # (up to 32 in quantized domain for fp8_e4m3), but the dequantized
    # error is bounded by the scale.
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        blk, off = divmod(slot, block_size)
        for src, cache, scale_cache, ref_scales in (
            (key, key_cache, k_scale_cache, ref_k_scales),
            (value, value_cache, v_scale_cache, ref_v_scales),
        ):
            slot_scales = scale_cache[blk, off]  # [num_heads]
            dequant = cache[blk, off].float() * slot_scales[:, None]
            torch.testing.assert_close(
                dequant, src[token_idx].float(), atol=0.1, rtol=0.1
            )
            # Per-head scales: [num_heads]
            torch.testing.assert_close(
                slot_scales, ref_scales[token_idx], atol=1e-4, rtol=1e-3
            )


# ===========================================================================
# 3. Per-token-head round-trip accuracy (quantize -> dequantize)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", [1, 16])
@pytest.mark.parametrize("num_heads", [4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_per_token_head_round_trip_accuracy(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
):
    """Verify per-token-head round-trip: kernel dequant matches the input.

    INT8: Triton truncates on float->int8 store.
    FP8: hardware cast (clamp then cast).
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    set_random_seed(42)

    num_blocks = (num_tokens + block_size - 1) // block_size + 2
    # Scale inputs down so values sit comfortably inside the quantized range.
    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5
    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    # Identity mapping: token i lands in slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.long)

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    for token_idx in range(num_tokens):
        blk, off = divmod(token_idx, block_size)
        for src, cache, scale_cache in (
            (key, key_cache, k_scale_cache),
            (value, value_cache, v_scale_cache),
        ):
            # Dequantize every head of this slot at once:
            # [num_heads, head_size] * [num_heads, 1]
            dequant = cache[blk, off].float() * scale_cache[blk, off][:, None]
            # Round-trip: dequantized should be close to original.
            torch.testing.assert_close(
                dequant, src[token_idx].float(), atol=0.1, rtol=0.1
            )


# ===========================================================================
# 4. Negative slot mapping (padding tokens should be skipped)
# ===========================================================================
@torch.inference_mode()
def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig):
    """Tokens with slot_mapping=-1 should leave the cache unchanged."""
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")

    num_tokens, num_heads, head_size = 4, 2, 64
    block_size, num_blocks = 16, 2

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    # Tokens 1 and 3 are padding (slot -1) and must be dropped by the kernel.
    slot_mapping = torch.tensor([0, -1, 1, -1], dtype=torch.long)

    key_snapshot = key_cache.clone()
    value_snapshot = value_cache.clone()

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Slots 0 and 1 should have been written (tokens 0 and 2)
    assert not torch.equal(key_cache[0, 0], key_snapshot[0, 0])
    assert not torch.equal(key_cache[0, 1], key_snapshot[0, 1])
    assert not torch.equal(value_cache[0, 0], value_snapshot[0, 0])
    # All other slots should be unchanged
    assert torch.equal(key_cache[0, 2:], key_snapshot[0, 2:])
    assert torch.equal(key_cache[1], key_snapshot[1])
    assert torch.equal(value_cache[0, 2:], value_snapshot[0, 2:])


# ===========================================================================
# 5. process_weights_after_loading -- per-token-head early return
# ===========================================================================
@pytest.mark.parametrize(
    "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
)
def test_process_weights_sets_placeholder_scales(kv_cache_dtype: str):
    """Per-token-head should set _k_scale=1.0, _v_scale=1.0 and delete
    checkpoint attrs."""
    from vllm.model_executor.layers.quantization.kv_cache import (
        BaseKVCacheMethod,
    )

    checkpoint_attrs = ("k_scale", "v_scale", "q_scale", "prob_scale")

    # Mimic an attention layer right after checkpoint load.
    layer = MagicMock()
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = False
    for attr in checkpoint_attrs:
        setattr(
            layer,
            attr,
            torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False),
        )
    layer._k_scale = torch.tensor(0.0)
    layer._v_scale = torch.tensor(0.0)
    layer._k_scale_float = 0.0
    layer._v_scale_float = 0.0

    # __new__ bypasses __init__ so no real quant config is required.
    method = BaseKVCacheMethod.__new__(BaseKVCacheMethod)
    method.quant_config = MagicMock()
    method.process_weights_after_loading(layer)

    assert layer._k_scale_float == 1.0
    assert layer._v_scale_float == 1.0
    # MagicMock tracks deletions, so hasattr is False only after delattr.
    for attr in checkpoint_attrs:
        assert not hasattr(layer, attr)


# ===========================================================================
# 6. Triton unified_attention -- per-token-head scale cache (INT8 and FP8)
# ===========================================================================
@pytest.mark.parametrize(
    "seq_lens",
    [
        [(1, 128)],
        [(1, 64), (1, 32)],
    ],
)
@pytest.mark.parametrize("num_heads", [(4, 4)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_triton_unified_attention_per_token_head_scale(
    qcfg: QuantConfig,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
):
    """End-to-end: quantized KV with per-token-head scale caches.

    Runs unified_attention twice -- once on the quantized caches with
    kv_quant_mode + scale caches, once on pre-dequantized bf16 caches --
    and requires the two outputs to agree within loose tolerances.
    """
    from vllm.utils.math_utils import next_power_of_2
    from vllm.v1.attention.ops.triton_unified_attention import unified_attention

    torch.set_default_device("cuda")
    set_random_seed(0)

    # seq_lens is a list of (query_len, kv_len) pairs, one per sequence.
    num_seqs = len(seq_lens)
    query_lens = [s[0] for s in seq_lens]
    kv_lens = [s[1] for s in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    num_blocks = 2048

    query = torch.randn(
        sum(query_lens), num_query_heads, head_size, dtype=torch.bfloat16
    )
    key_cache_bf16 = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16
    )
    value_cache_bf16 = torch.randn_like(key_cache_bf16)

    # Per-token-head quantization: one scale per (block, slot, head)
    k_absmax = key_cache_bf16.float().abs().amax(dim=-1)  # [..., num_kv_heads]
    v_absmax = value_cache_bf16.float().abs().amax(dim=-1)
    k_scale_cache = (k_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
    v_scale_cache = (v_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
    scaled_k = key_cache_bf16.float() / k_scale_cache[:, :, :, None]
    scaled_v = value_cache_bf16.float() / v_scale_cache[:, :, :, None]
    # INT8 path rounds explicitly; FP8 path clamps and lets the cast round.
    if qcfg.uses_trunc:
        key_cache_q = (
            scaled_k.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
        )
        value_cache_q = (
            scaled_v.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
        )
    else:
        key_cache_q = scaled_k.clamp(qcfg.quant_min, qcfg.quant_max).to(
            qcfg.cache_dtype
        )
        value_cache_q = scaled_v.clamp(qcfg.quant_min, qcfg.quant_max).to(
            qcfg.cache_dtype
        )

    # Dequantized reference
    key_cache_deq = key_cache_q.float() * k_scale_cache[:, :, :, None]
    value_cache_deq = value_cache_q.float() * v_scale_cache[:, :, :, None]

    # Varlen layout: cumulative query offsets + per-sequence kv lengths.
    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    head_size_padded = next_power_of_2(head_size)
    # seq_threshold_3D=0 makes all softmax segment buffers zero-sized;
    # presumably this keeps the kernel on its non-segmented path -- confirm
    # against the unified_attention implementation.
    seq_threshold_3D = 0
    num_par_softmax_segments = 16
    softmax_segm_output = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
        dtype=torch.float32,
    )
    softmax_segm_max = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32,
    )
    softmax_segm_expsum = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32,
    )

    # Quantized path: kernel dequantizes on the fly via the scale caches.
    output_q = torch.empty_like(query)
    unified_attention(
        q=query,
        k=key_cache_q,
        v=value_cache_q,
        out=output_q,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_t,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=(-1, -1),
        block_table=block_tables,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
        kv_quant_mode=qcfg.kv_quant_mode,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
    )

    # Reference path: same kernel on pre-dequantized bf16 caches,
    # without kv_quant_mode or scale caches.
    output_ref = torch.empty_like(query)
    unified_attention(
        q=query,
        k=key_cache_deq.to(torch.bfloat16),
        v=value_cache_deq.to(torch.bfloat16),
        out=output_ref,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_t,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=(-1, -1),
        block_table=block_tables,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
    )

    # Loose tolerances: both paths see identical quantization error, but
    # in-kernel dequant order differs from the precomputed bf16 reference.
    torch.testing.assert_close(output_q, output_ref, atol=5e-2, rtol=5e-2)