[Kernel] Add FP8 KV cache support to Triton MLA decode attention (#34597)

Signed-off-by: grimulkan <grimulkan@gmail.com>
This commit is contained in:
grimulkan
2026-03-12 10:32:34 -05:00
committed by GitHub
parent abcffbba8c
commit a1257fd1ea
4 changed files with 192 additions and 8 deletions

View File

@@ -213,5 +213,5 @@ configuration.
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |

View File

@@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
)
assert torch.allclose(o, o1)
def _quantize_to_fp8(tensor: torch.Tensor):
"""Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale.
Returns (fp8_tensor, scale) where:
fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn)
tensor ≈ fp8_tensor.to(float32) * scale (dequantized)
"""
amax = tensor.abs().amax()
# float8_e4m3fn max representable value is 448.0
scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32)
fp8_tensor = (
(tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
)
return fp8_tensor, scale
@pytest.mark.parametrize("B", [3])
@pytest.mark.parametrize("L", [1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
"""Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales,
and compare against BF16 reference output."""
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint(
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
)
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# Create BF16 K/V as reference
k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# --- BF16 reference ---
o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
lse_ref = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
)
if PAGE_SIZE == 1:
decode_attention_fwd(
q,
k_bf16,
v_bf16,
o_ref,
lse_ref,
req_to_token,
b_seq_len=torch.full((B,), seq_len, device="cuda"),
attn_logits=attn_logits,
num_kv_splits=num_kv_splits,
sm_scale=sm_scale,
)
else:
k_paged = k_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_paged = v_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
decode_attention_fwd(
q,
k_paged,
v_paged,
o_ref,
lse_ref,
req_to_page,
b_seq_len=torch.full((B,), seq_len, device="cuda"),
attn_logits=attn_logits,
num_kv_splits=num_kv_splits,
sm_scale=sm_scale,
page_size=PAGE_SIZE,
)
# --- FP8 path ---
k_fp8, k_scale = _quantize_to_fp8(k_bf16)
v_fp8, v_scale = _quantize_to_fp8(v_bf16)
o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
attn_logits_fp8 = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
)
if PAGE_SIZE == 1:
decode_attention_fwd(
q,
k_fp8,
v_fp8,
o_fp8,
lse_fp8,
req_to_token,
b_seq_len=torch.full((B,), seq_len, device="cuda"),
attn_logits=attn_logits_fp8,
num_kv_splits=num_kv_splits,
sm_scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
)
else:
k_fp8_paged = k_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_fp8_paged = v_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
decode_attention_fwd(
q,
k_fp8_paged,
v_fp8_paged,
o_fp8,
lse_fp8,
req_to_page,
b_seq_len=torch.full((B,), seq_len, device="cuda"),
attn_logits=attn_logits_fp8,
num_kv_splits=num_kv_splits,
sm_scale=sm_scale,
page_size=PAGE_SIZE,
k_scale=k_scale,
v_scale=v_scale,
)
# FP8 tolerances match test_mla_backends.py test_backend_correctness.
torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2)

View File

@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]
@classmethod
@@ -108,10 +110,11 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"TritonMLAImpl"
)
# For FP8 KV cache, we dequantize to BF16 on load inside the
# Triton kernel. Tell the common layer not to quantize queries
# to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1).
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
self.supports_quant_query_input = False
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
@@ -135,9 +138,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
if type(q) is tuple:
q = torch.cat(q, dim=-1)
@@ -171,7 +171,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
# Run MQA — always pass layer scales. When KV cache is
# BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
decode_attention_fwd(
q,
kv_c_and_k_pe_cache,
@@ -184,6 +185,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
num_kv_splits,
self.scale,
PAGE_SIZE,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)
return o, lse

View File

@@ -31,6 +31,7 @@ It supports page size >= 1.
import logging
import torch
from packaging import version
from vllm.platforms import current_platform
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
k_scale,
v_scale,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
):
BLOCK = 64 if not is_hip_ else 8
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
k_scale,
v_scale,
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
k_scale,
v_scale,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
if kpe.dtype.is_fp8():
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
k_scale,
v_scale,
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
_decode_att_m_fwd(
q,
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
_decode_grouped_att_m_fwd(
q,
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -671,8 +706,16 @@ def decode_attention_fwd(
sm_scale,
page_size=1,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
assert num_kv_splits == attn_logits.shape[2]
if k_scale is None:
k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
if v_scale is None:
v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
kv_group_num = q.shape[1] // v_buffer.shape[-2]
if kv_group_num == 1:
@@ -690,6 +733,8 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
else:
# GQA/MQA/MLA
@@ -706,4 +751,6 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)