[Kernel] Add FP8 KV cache support to Triton MLA decode attention (#34597)
Signed-off-by: grimulkan <grimulkan@gmail.com>
@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "bfloat16",
+        "fp8",
+        "fp8_e4m3",
     ]
 
     @classmethod
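Note: with these entries, an engine configured with kv_cache_dtype="fp8" (or "fp8_e4m3") can now select this backend instead of being rejected. A minimal usage sketch; the model name is an arbitrary example, not part of this change:

    from vllm import LLM

    # Hypothetical example: any MLA model would do; "fp8" now routes
    # through TritonMLABackend.
    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", kv_cache_dtype="fp8")
    outputs = llm.generate("Hello")
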
@@ -108,10 +110,11 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "TritonMLAImpl"
             )
 
+        # For FP8 KV cache, we dequantize to BF16 on load inside the
+        # Triton kernel. Tell the common layer not to quantize queries
+        # to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1).
         if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "TritonMLA V1 with FP8 KV cache not yet supported"
-            )
+            self.supports_quant_query_input = False
 
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
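The "Mode 1" in the new comment is FP8 KV cache paired with BF16 queries: the cache stays FP8 in memory and each tile is dequantized as it is loaded, so queries are never quantized. A rough PyTorch equivalent of the score computation, as a sketch (names are illustrative, not from this diff):

    import torch

    def mode1_scores(q_bf16: torch.Tensor, k_fp8: torch.Tensor,
                     k_scale: torch.Tensor) -> torch.Tensor:
        # K is stored FP8 and dequantized to the query dtype on load,
        # mirroring the Triton kernel; Q stays BF16 throughout.
        k = (k_fp8.to(torch.float32) * k_scale).to(q_bf16.dtype)
        return q_bf16 @ k.transpose(-1, -2)
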
@@ -135,9 +138,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")
-
         if type(q) is tuple:
             q = torch.cat(q, dim=-1)
 
@@ -171,7 +171,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
         PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
 
-        # Run MQA
+        # Run MQA — always pass layer scales. When KV cache is
+        # BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
         decode_attention_fwd(
             q,
             kv_c_and_k_pe_cache,
@@ -184,6 +185,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             num_kv_splits,
             self.scale,
             PAGE_SIZE,
+            k_scale=layer._k_scale,
+            v_scale=layer._v_scale,
         )
 
         return o, lse

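layer._k_scale and layer._v_scale are the per-layer dequantization scales vLLM keeps on the attention layer as device tensors. Passing them unconditionally is safe because the kernel only applies them behind an `if dtype.is_fp8()` check, which compiles away for a BF16 cache. A hedged reference for the per-tile semantics:

    import torch

    def dequant_ref(tile: torch.Tensor, scale: torch.Tensor,
                    out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        # Mirrors `(k.to(tl.float32) * ks).to(q.dtype)` in the kernel.
        return (tile.to(torch.float32) * scale).to(out_dtype)
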
@@ -31,6 +31,7 @@ It supports page size >= 1.
 
 import logging
 
 import torch
+from packaging import version
 
 from vllm.platforms import current_platform
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
+    k_scale,
+    v_scale,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        ks = tl.load(k_scale)
+        vs = tl.load(v_scale)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_page_number = tl.load(
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
                 mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                 other=0.0,
             )
+            if k.dtype.is_fp8():
+                k = (k.to(tl.float32) * ks).to(q.dtype)
             qk = tl.sum(q[None, :] * k, 1)
             qk *= sm_scale
 
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
                 mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                 other=0.0,
             )
+            if v.dtype.is_fp8():
+                v = (v.to(tl.float32) * vs).to(q.dtype)
 
             n_e_max = tl.maximum(tl.max(qk, 0), e_max)
             re_scale = tl.exp(e_max - n_e_max)
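The pattern in the two hunks above (load the scalar scale once per program, branch on the compile-time tile dtype) can be exercised in isolation. A self-contained sketch, assuming a Triton build with FP8 dtypes and a CUDA device; none of these names come from the PR:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dequant_copy(x_ptr, scale_ptr, out_ptr, N, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        # The dtype is known at compile time, so this branch costs
        # nothing when x is BF16; that is why the main kernels can
        # take k_scale/v_scale unconditionally.
        if x.dtype.is_fp8():
            s = tl.load(scale_ptr)  # 0-dim scale tensor, loaded once
            x = (x.to(tl.float32) * s).to(tl.bfloat16)
        tl.store(out_ptr + offs, x, mask=mask)

    x = torch.randn(1024, device="cuda").to(torch.float8_e4m3fn)
    s = torch.tensor(2.0, dtype=torch.float32, device="cuda")
    out = torch.empty(1024, device="cuda", dtype=torch.bfloat16)
    dequant_copy[(4,)](x, s, out, 1024, BLOCK=256)
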
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
    sm_scale,
    page_size,
    logit_cap,
+    k_scale,
+    v_scale,
 ):
     BLOCK = 64 if not is_hip_ else 8
 
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
+        k_scale,
+        v_scale,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
+    k_scale,
+    v_scale,
     kv_group_num: tl.constexpr,
     q_head_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        ks = tl.load(k_scale)
+        vs = tl.load(v_scale)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_page_number = tl.load(
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
                 mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                 other=0.0,
             )
+            if k.dtype.is_fp8():
+                k = (k.to(tl.float32) * ks).to(q.dtype)
             qk = tl.dot(q, k.to(q.dtype))
             if BLOCK_DPE > 0:
                 offs_buf_kpe = (
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
                     mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                     other=0.0,
                 )
+                if kpe.dtype.is_fp8():
+                    kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
                 qk += tl.dot(qpe, kpe.to(qpe.dtype))
             qk *= sm_scale
 
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
                 mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                 other=0.0,
             )
+            if v.dtype.is_fp8():
+                v = (v.to(tl.float32) * vs).to(q.dtype)
 
             n_e_max = tl.maximum(tl.max(qk, 1), e_max)
             re_scale = tl.exp(e_max - n_e_max)
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
     sm_scale,
     page_size,
     logit_cap,
+    k_scale,
+    v_scale,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
+        k_scale,
+        v_scale,
         kv_group_num=kv_group_num,
         q_head_num=head_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
     sm_scale,
     page_size,
     logit_cap=0.0,
+    k_scale=None,
+    v_scale=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
         sm_scale,
         page_size,
         logit_cap,
+        k_scale,
+        v_scale,
     )
     _decode_softmax_reducev_fwd(
         attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
     sm_scale,
     page_size,
     logit_cap=0.0,
+    k_scale=None,
+    v_scale=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
         sm_scale,
         page_size,
         logit_cap,
+        k_scale,
+        v_scale,
     )
     _decode_softmax_reducev_fwd(
         attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -671,8 +706,16 @@ def decode_attention_fwd(
     sm_scale,
     page_size=1,
     logit_cap=0.0,
+    k_scale=None,
+    v_scale=None,
 ):
     assert num_kv_splits == attn_logits.shape[2]
+
+    if k_scale is None:
+        k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
+    if v_scale is None:
+        v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
+
     kv_group_num = q.shape[1] // v_buffer.shape[-2]
 
     if kv_group_num == 1:
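The defaults are built as 0-dim float32 tensors on the query's device rather than Python floats because the kernels read them with tl.load; with a scale of 1.0 the FP8 branch degenerates to a plain upcast, since every e4m3 value is exactly representable in BF16. A quick check of that identity, as a sketch:

    import torch

    one = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    x = torch.randn(8, device="cuda").to(torch.float8_e4m3fn)
    assert torch.equal((x.to(torch.float32) * one).to(torch.bfloat16),
                       x.to(torch.bfloat16))
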
@@ -690,6 +733,8 @@ def decode_attention_fwd(
             sm_scale,
             page_size,
             logit_cap,
+            k_scale,
+            v_scale,
         )
     else:
         # GQA/MQA/MLA
@@ -706,4 +751,6 @@ def decode_attention_fwd(
             sm_scale,
             page_size,
             logit_cap,
+            k_scale,
+            v_scale,
         )