[Kernel] Add FP8 KV cache support to Triton MLA decode attention (#34597)
Signed-off-by: grimulkan <grimulkan@gmail.com>
This commit is contained in:
@@ -213,5 +213,5 @@ configuration.
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
-| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
+| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
@@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
)
|
||||
|
||||
assert torch.allclose(o, o1)
|
||||
|
||||
|
||||
def _quantize_to_fp8(tensor: torch.Tensor):
|
||||
"""Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale.
|
||||
|
||||
Returns (fp8_tensor, scale) where:
|
||||
fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn)
|
||||
tensor ≈ fp8_tensor.to(float32) * scale (dequantized)
|
||||
"""
|
||||
amax = tensor.abs().amax()
|
||||
# float8_e4m3fn max representable value is 448.0
|
||||
scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32)
|
||||
fp8_tensor = (
|
||||
(tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
|
||||
)
|
||||
return fp8_tensor, scale
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B", [3])
@pytest.mark.parametrize("L", [1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales,
    and compare against BF16 reference output."""
    assert CACHE_SIZE % PAGE_SIZE == 0
    ref_dtype = torch.bfloat16
    seq_len = L
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Build the page table: random page ids per request, then expand each
    # page id into per-token slot indices inside the cache.
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(
        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
    )
    token_base = (req_to_page * PAGE_SIZE).expand(B, num_pages_per_batch, PAGE_SIZE)
    in_page_offset = torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = (token_base + in_page_offset).view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # Keep RNG-call order identical to the original test: q, then K, then V.
    q = torch.randn(B, H_Q, D_QK, dtype=ref_dtype, device="cuda")

    # Create BF16 K/V as reference
    k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=ref_dtype, device="cuda")
    v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=ref_dtype, device="cuda")

    def _run(k_cache, v_cache, scale_kwargs):
        """Run decode_attention_fwd once on (k_cache, v_cache); return output."""
        out = torch.zeros(B, H_Q, D_V, dtype=ref_dtype, device="cuda")
        lse = torch.zeros(B, H_Q, dtype=ref_dtype, device="cuda")
        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
        )
        if PAGE_SIZE == 1:
            decode_attention_fwd(
                q,
                k_cache,
                v_cache,
                out,
                lse,
                req_to_token,
                b_seq_len=torch.full((B,), seq_len, device="cuda"),
                attn_logits=attn_logits,
                num_kv_splits=num_kv_splits,
                sm_scale=sm_scale,
                **scale_kwargs,
            )
        else:
            n_pages = CACHE_SIZE // PAGE_SIZE
            decode_attention_fwd(
                q,
                k_cache.view(n_pages, PAGE_SIZE, H_KV, D_QK),
                v_cache.view(n_pages, PAGE_SIZE, H_KV, D_V),
                out,
                lse,
                req_to_page,
                b_seq_len=torch.full((B,), seq_len, device="cuda"),
                attn_logits=attn_logits,
                num_kv_splits=num_kv_splits,
                sm_scale=sm_scale,
                page_size=PAGE_SIZE,
                **scale_kwargs,
            )
        return out

    # --- BF16 reference ---
    o_ref = _run(k_bf16, v_bf16, {})

    # --- FP8 path: quantize K/V per-tensor and pass scales to the kernel ---
    k_fp8, k_scale = _quantize_to_fp8(k_bf16)
    v_fp8, v_scale = _quantize_to_fp8(v_bf16)
    o_fp8 = _run(k_fp8, v_fp8, {"k_scale": k_scale, "v_scale": v_scale})

    # FP8 tolerances match test_mla_backends.py test_backend_correctness.
    torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2)
|
||||
|
||||
@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@@ -108,10 +110,11 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
"TritonMLAImpl"
|
||||
)
|
||||
|
||||
# For FP8 KV cache, we dequantize to BF16 on load inside the
|
||||
# Triton kernel. Tell the common layer not to quantize queries
|
||||
# to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1).
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
self.supports_quant_query_input = False
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
@@ -135,9 +138,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
@@ -171,7 +171,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
|
||||
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||
|
||||
# Run MQA
|
||||
# Run MQA — always pass layer scales. When KV cache is
|
||||
# BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
kv_c_and_k_pe_cache,
|
||||
@@ -184,6 +185,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
num_kv_splits,
|
||||
self.scale,
|
||||
PAGE_SIZE,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
)
|
||||
|
||||
return o, lse
|
||||
|
||||
@@ -31,6 +31,7 @@ It supports page size >= 1.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
k_scale,
|
||||
v_scale,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
ks = tl.load(k_scale)
|
||||
vs = tl.load(v_scale)
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
if k.dtype.is_fp8():
|
||||
k = (k.to(tl.float32) * ks).to(q.dtype)
|
||||
qk = tl.sum(q[None, :] * k, 1)
|
||||
qk *= sm_scale
|
||||
|
||||
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
if v.dtype.is_fp8():
|
||||
v = (v.to(tl.float32) * vs).to(q.dtype)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
):
|
||||
BLOCK = 64 if not is_hip_ else 8
|
||||
|
||||
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
k_scale,
|
||||
v_scale,
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
k_scale,
|
||||
v_scale,
|
||||
kv_group_num: tl.constexpr,
|
||||
q_head_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
|
||||
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
ks = tl.load(k_scale)
|
||||
vs = tl.load(v_scale)
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
if k.dtype.is_fp8():
|
||||
k = (k.to(tl.float32) * ks).to(q.dtype)
|
||||
qk = tl.dot(q, k.to(q.dtype))
|
||||
if BLOCK_DPE > 0:
|
||||
offs_buf_kpe = (
|
||||
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
if kpe.dtype.is_fp8():
|
||||
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
|
||||
qk += tl.dot(qpe, kpe.to(qpe.dtype))
|
||||
qk *= sm_scale
|
||||
|
||||
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
if v.dtype.is_fp8():
|
||||
v = (v.to(tl.float32) * vs).to(q.dtype)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
):
|
||||
BLOCK = 32
|
||||
Lk = k_buffer.shape[-1]
|
||||
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
k_scale,
|
||||
v_scale,
|
||||
kv_group_num=kv_group_num,
|
||||
q_head_num=head_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
|
||||
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
|
||||
@@ -671,8 +706,16 @@ def decode_attention_fwd(
|
||||
sm_scale,
|
||||
page_size=1,
|
||||
logit_cap=0.0,
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
):
|
||||
assert num_kv_splits == attn_logits.shape[2]
|
||||
|
||||
if k_scale is None:
|
||||
k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
|
||||
if v_scale is None:
|
||||
v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
|
||||
|
||||
kv_group_num = q.shape[1] // v_buffer.shape[-2]
|
||||
|
||||
if kv_group_num == 1:
|
||||
@@ -690,6 +733,8 @@ def decode_attention_fwd(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
@@ -706,4 +751,6 @@ def decode_attention_fwd(
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user