[Kernel] Add FP8 KV cache support to Triton MLA decode attention (#34597)

Signed-off-by: grimulkan <grimulkan@gmail.com>
This commit is contained in:
grimulkan
2026-03-12 10:32:34 -05:00
committed by GitHub
parent abcffbba8c
commit a1257fd1ea
4 changed files with 192 additions and 8 deletions

View File

@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]
@classmethod
@@ -108,10 +110,11 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"TritonMLAImpl"
)
# For FP8 KV cache, we dequantize to BF16 on load inside the
# Triton kernel. Tell the common layer not to quantize queries
# to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1).
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
self.supports_quant_query_input = False
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
@@ -135,9 +138,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
if type(q) is tuple:
q = torch.cat(q, dim=-1)
@@ -171,7 +171,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
# Run MQA — always pass layer scales. When KV cache is
# BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
decode_attention_fwd(
q,
kv_c_and_k_pe_cache,
@@ -184,6 +185,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
num_kv_splits,
self.scale,
PAGE_SIZE,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)
return o, lse

View File

@@ -31,6 +31,7 @@ It supports page size >= 1.
import logging
import torch
from packaging import version
from vllm.platforms import current_platform
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
k_scale,
v_scale,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
):
BLOCK = 64 if not is_hip_ else 8
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
k_scale,
v_scale,
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
k_scale,
v_scale,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
if kpe.dtype.is_fp8():
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
k_scale,
v_scale,
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
_decode_att_m_fwd(
q,
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
_decode_grouped_att_m_fwd(
q,
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -671,8 +706,16 @@ def decode_attention_fwd(
sm_scale,
page_size=1,
logit_cap=0.0,
k_scale=None,
v_scale=None,
):
assert num_kv_splits == attn_logits.shape[2]
if k_scale is None:
k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
if v_scale is None:
v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
kv_group_num = q.shape[1] // v_buffer.shape[-2]
if kv_group_num == 1:
@@ -690,6 +733,8 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)
else:
# GQA/MQA/MLA
@@ -706,4 +751,6 @@ def decode_attention_fwd(
sm_scale,
page_size,
logit_cap,
k_scale,
v_scale,
)