[Kernel] Support decode context parallelism on Blackwell with CUTLASS MLA (#24385)
Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -14,14 +15,20 @@ from vllm.triton_utils import triton
|
||||
def cal_diff(x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
name: str,
|
||||
use_fp8: bool = False) -> None:
|
||||
use_fp8: bool = False,
|
||||
diff_threshold: Optional[float] = None) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
||||
(x * x + y * y).sum().item(), 1e-12)
|
||||
if (use_fp8):
|
||||
assert cos_diff < 1e-4
|
||||
if diff_threshold is not None:
|
||||
# directly compare the cos_diff with the threshold
|
||||
assert cos_diff < diff_threshold
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
# use the default threshold
|
||||
if (use_fp8):
|
||||
assert cos_diff < 1e-4
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = \
|
||||
@@ -118,11 +125,13 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
dtype=torch.uint8)
|
||||
|
||||
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
|
||||
|
||||
ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat,
|
||||
cache_seqlens, block_table, workspace,
|
||||
scale, 1)
|
||||
return out_ans[:, :h_q].contiguous()
|
||||
output_lse = torch.empty((b, MAX_HEADS),
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe,
|
||||
kv_cache_flat, cache_seqlens, block_table,
|
||||
workspace, scale, 1)
|
||||
return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
@@ -165,11 +174,14 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
lse[i] = lse_i
|
||||
return out, lse
|
||||
|
||||
out_cutlass = cutlass_mla()
|
||||
out_cutlass, lse_cutlass = cutlass_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
# Extract the single token (s_q=1) slice to match cutlass output shape
|
||||
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
|
||||
lse_torch_slice = lse_torch[:, 0, :] # [b, h_q]
|
||||
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
|
||||
# lse has larger numerical error, so use a larger threshold
|
||||
cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3)
|
||||
|
||||
t = triton.testing.do_bench(cutlass_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
|
||||
Reference in New Issue
Block a user