[Attention][Perf][Kernel] Replace torch.cat with vectorized CUDA kernel MLA query concat - DeepSeek-V3.2 (#34917)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
2b28b9b269
commit
580864d81e
98
benchmarks/kernels/bench_concat_mla_q.py
Normal file
98
benchmarks/kernels/bench_concat_mla_q.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
# DeepSeek V3 dimensions
|
||||
NOPE_DIM = 512
|
||||
ROPE_DIM = 64
|
||||
NUM_HEADS = 128
|
||||
|
||||
NUM_TOKENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
|
||||
def get_configs():
|
||||
return NUM_TOKENS
|
||||
|
||||
|
||||
def make_inputs(num_tokens, dtype):
|
||||
"""Create inputs matching the real code path.
|
||||
|
||||
Args:
|
||||
contiguous_nope: If False, simulate the transposed BMM output
|
||||
(non-contiguous nope with stride pattern from
|
||||
[N,B,L].transpose(0,1)).
|
||||
"""
|
||||
# Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L]
|
||||
raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda")
|
||||
ql_nope = raw.transpose(0, 1)
|
||||
|
||||
q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda")
|
||||
return ql_nope, q_pe
|
||||
|
||||
|
||||
# ---- Non-contiguous nope benchmark (real code path) ----
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=get_configs(),
|
||||
line_arg="provider",
|
||||
line_vals=["torch_cat", "concat_mla_q"],
|
||||
line_names=["torch.cat", "concat_mla_q (v8)"],
|
||||
styles=[("blue", "--"), ("green", "-")],
|
||||
ylabel="Latency (us)",
|
||||
plot_name="concat_mla_q-transposed",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def bench_transposed(num_tokens, provider):
|
||||
dtype = torch.bfloat16
|
||||
ql_nope, q_pe = make_inputs(num_tokens, dtype)
|
||||
|
||||
q_out = torch.empty(
|
||||
num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch_cat":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), quantiles=quantiles, rep=500
|
||||
)
|
||||
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # us
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat")
|
||||
parser.add_argument(
|
||||
"--save-path", type=str, default=None, help="Path to save benchmark results"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("CONCAT MLA Q KERNEL BENCHMARKS")
|
||||
print("=" * 70)
|
||||
print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}")
|
||||
print(
|
||||
f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = "
|
||||
f"{(NOPE_DIM + ROPE_DIM) * 2} bytes"
|
||||
)
|
||||
print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n--- Non-contiguous nope inputs (transposed BMM output) ---")
|
||||
bench_transposed.run(print_data=True, save_path=args.save_path)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Benchmarking complete!")
|
||||
print("=" * 70)
|
||||
Reference in New Issue
Block a user