# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.utils.math_utils import cdiv
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)

    num_kv_splits = 8

    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(
        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
    )
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    # Page size is 1.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o will have the same shape as q
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
    lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B,), seq_len, device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        lse,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)
    lse1 = torch.zeros_like(lse)

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        lse1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    assert torch.allclose(o, o1)


def _quantize_to_fp8(tensor: torch.Tensor):
    """Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale.

    Returns (fp8_tensor, scale) where:
        fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn)
        tensor ≈ fp8_tensor.to(float32) * scale (dequantized)
    """
    amax = tensor.abs().amax()
    # float8_e4m3fn max representable value is 448.0
    scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32)
    fp8_tensor = (
        (tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    )
    return fp8_tensor, scale


@pytest.mark.parametrize("B", [3])
@pytest.mark.parametrize("L", [1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales,
    and compare against the BF16 reference output."""
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(
        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
    )
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # Create BF16 K/V as reference
    k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # --- BF16 reference ---
    o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
    lse_ref = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
    )

    if PAGE_SIZE == 1:
        decode_attention_fwd(
            q,
            k_bf16,
            v_bf16,
            o_ref,
            lse_ref,
            req_to_token,
            b_seq_len=torch.full((B,), seq_len, device="cuda"),
            attn_logits=attn_logits,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
        )
    else:
        k_paged = k_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
        v_paged = v_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
        decode_attention_fwd(
            q,
            k_paged,
            v_paged,
            o_ref,
            lse_ref,
            req_to_page,
            b_seq_len=torch.full((B,), seq_len, device="cuda"),
            attn_logits=attn_logits,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
            page_size=PAGE_SIZE,
        )

    # --- FP8 path ---
    k_fp8, k_scale = _quantize_to_fp8(k_bf16)
    v_fp8, v_scale = _quantize_to_fp8(v_bf16)

    o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
    lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
    attn_logits_fp8 = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
    )

    if PAGE_SIZE == 1:
        decode_attention_fwd(
            q,
            k_fp8,
            v_fp8,
            o_fp8,
            lse_fp8,
            req_to_token,
            b_seq_len=torch.full((B,), seq_len, device="cuda"),
            attn_logits=attn_logits_fp8,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
        )
    else:
        k_fp8_paged = k_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
        v_fp8_paged = v_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
        decode_attention_fwd(
            q,
            k_fp8_paged,
            v_fp8_paged,
            o_fp8,
            lse_fp8,
            req_to_page,
            b_seq_len=torch.full((B,), seq_len, device="cuda"),
            attn_logits=attn_logits_fp8,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
            page_size=PAGE_SIZE,
            k_scale=k_scale,
            v_scale=v_scale,
        )

    # FP8 tolerances match test_mla_backends.py test_backend_correctness.
    torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2)
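

# Illustrative sketch: a minimal sanity check of the round-trip property stated in
# the _quantize_to_fp8 docstring (dequantizing with the returned per-tensor scale
# approximately recovers the input). The shape and tolerances below are assumptions
# chosen for illustration (FP8 e4m3 keeps ~3 mantissa bits, so worst-case relative
# rounding error is roughly 6%), not values taken from vLLM's test suite.
def test_quantize_to_fp8_roundtrip():
    x = torch.randn(64, 8, 128, dtype=torch.bfloat16, device="cuda")
    x_fp8, scale = _quantize_to_fp8(x)
    # Dequantize exactly as described in the helper's docstring.
    x_dequant = x_fp8.to(torch.float32) * scale
    torch.testing.assert_close(x.to(torch.float32), x_dequant, atol=1e-2, rtol=1e-1)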