# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
|
|
import torch
|
|
|
|
from vllm.utils.math_utils import cdiv
|
|
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
|
|
|
|
|
@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Check that the paged call agrees with the token-indexed call.

    The same random Q/K/V data is pushed through ``decode_attention_fwd``
    twice: first with the flat cache and a per-token slot table
    (``req_to_token``), then with the cache viewed as ``PAGE_SIZE``-sized
    pages, the page table (``req_to_page``) and ``PAGE_SIZE`` passed as an
    extra argument. The two attention outputs must match.
    """
    assert CACHE_SIZE % PAGE_SIZE == 0

    device = "cuda"
    dtype = torch.bfloat16
    total_pages = CACHE_SIZE // PAGE_SIZE
    seq_len = L  # Number of tokens already present in every sequence.
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Random page table: one page id per (request, logical page).
    pages_per_seq = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(
        0, total_pages, (B, pages_per_seq, 1), device=device
    )

    # Per-token slot table derived from the page table: first slot of each
    # page plus the in-page offset, flattened per request and trimmed to the
    # sequence length (the last page may be only partially used).
    in_page_offsets = torch.arange(PAGE_SIZE, device=device)
    req_to_token = (req_to_page * PAGE_SIZE + in_page_offsets).view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # One query vector per request (the token currently being decoded).
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=device)

    # Flat K/V caches holding every slot; indexed per token in the first call.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device=device)
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device=device)

    # Output buffers for the token-indexed call.
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
    lse = torch.zeros(B, H_Q, dtype=dtype, device=device)

    b_seq_len = torch.full((B,), seq_len, device=device)

    # Scratch buffer handed to the kernel; shared by both calls.
    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device=device,
    )

    # First call: flat cache + per-token slot table.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        lse,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Second call: same storage reinterpreted as pages + page table.
    k_buffer = k_buffer.view(total_pages, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(total_pages, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)
    lse1 = torch.zeros_like(lse)

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        lse1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    assert torch.allclose(o, o1)
|
|
|
|
|
|
def _quantize_to_fp8(tensor: torch.Tensor):
|
|
"""Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale.
|
|
|
|
Returns (fp8_tensor, scale) where:
|
|
fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn)
|
|
tensor ≈ fp8_tensor.to(float32) * scale (dequantized)
|
|
"""
|
|
amax = tensor.abs().amax()
|
|
# float8_e4m3fn max representable value is 448.0
|
|
scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32)
|
|
fp8_tensor = (
|
|
(tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
|
|
)
|
|
return fp8_tensor, scale
|
|
|
|
|
|
@pytest.mark.parametrize("B", [3])
@pytest.mark.parametrize("L", [1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Compare the FP8 KV-cache path against a BF16 run on the same data.

    K/V are generated in BF16 and run through the kernel unchanged to obtain
    the reference output. They are then quantized to FP8 with per-tensor
    scales and run again with ``k_scale``/``v_scale`` supplied. Both outputs
    must agree within a loose FP8 tolerance.
    """
    assert CACHE_SIZE % PAGE_SIZE == 0

    device = "cuda"
    dtype = torch.bfloat16
    total_pages = CACHE_SIZE // PAGE_SIZE
    seq_len = L
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Random page table plus the per-token slot table derived from it
    # (first slot of each page + in-page offset, trimmed to seq_len).
    pages_per_seq = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(
        0, total_pages, (B, pages_per_seq, 1), device=device
    )
    in_page_offsets = torch.arange(PAGE_SIZE, device=device)
    req_to_token = (req_to_page * PAGE_SIZE + in_page_offsets).view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=device)

    # BF16 K/V: used directly for the reference and as the quantizer input.
    k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device=device)
    v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device=device)

    def run_decode(k_cache, v_cache, **scale_kwargs):
        """Run the kernel on fresh buffers and return the attention output.

        ``scale_kwargs`` is forwarded untouched; it is empty for the BF16
        reference and carries ``k_scale``/``v_scale`` for the FP8 run.
        """
        out = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
        lse = torch.zeros(B, H_Q, dtype=dtype, device=device)
        logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device=device
        )
        shared_kwargs = dict(
            b_seq_len=torch.full((B,), seq_len, device=device),
            attn_logits=logits,
            num_kv_splits=num_kv_splits,
            sm_scale=sm_scale,
            **scale_kwargs,
        )

        if PAGE_SIZE == 1:
            # Flat cache addressed through the per-token slot table.
            decode_attention_fwd(
                q, k_cache, v_cache, out, lse, req_to_token, **shared_kwargs
            )
        else:
            # Same storage viewed as pages, addressed through the page table.
            decode_attention_fwd(
                q,
                k_cache.view(total_pages, PAGE_SIZE, H_KV, D_QK),
                v_cache.view(total_pages, PAGE_SIZE, H_KV, D_V),
                out,
                lse,
                req_to_page,
                page_size=PAGE_SIZE,
                **shared_kwargs,
            )
        return out

    # --- BF16 reference ---
    o_ref = run_decode(k_bf16, v_bf16)

    # --- FP8 path ---
    k_fp8, k_scale = _quantize_to_fp8(k_bf16)
    v_fp8, v_scale = _quantize_to_fp8(v_bf16)
    o_fp8 = run_decode(k_fp8, v_fp8, k_scale=k_scale, v_scale=v_scale)

    # FP8 tolerances match test_mla_backends.py test_backend_correctness.
    torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2)
|