[Perf] Optimize moe_align_block_size CUDA kernel (#19572)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
159
benchmarks/kernels/benchmark_moe_align_block_size.py
Normal file
159
benchmarks/kernels/benchmark_moe_align_block_size.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size_triton,
|
||||
)
|
||||
|
||||
|
||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[
|
||||
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||
for _ in range(num_tokens)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
|
||||
"""
|
||||
Verifies vllm vs. Triton
|
||||
"""
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
|
||||
# 1. malloc space for triton and vllm
|
||||
# malloc enough space (max_num_tokens_padded) for the sorted ids
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
|
||||
expert_ids_triton = torch.zeros(
|
||||
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
|
||||
sorted_ids_vllm.fill_(topk_ids.numel())
|
||||
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
|
||||
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
|
||||
|
||||
# 2. run implementations
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_triton,
|
||||
expert_ids_triton,
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
|
||||
# 3. compare results
|
||||
if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
|
||||
num_tokens_post_pad_triton, num_tokens_post_pad_vllm
|
||||
):
|
||||
print("✅ Triton and VLLM implementations match.")
|
||||
else:
|
||||
print("❌ Triton and VLLM implementations DO NOT match.")
|
||||
print("Triton expert_ids:", expert_ids_triton)
|
||||
print("VLLM expert_ids:", expert_ids_vllm)
|
||||
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
||||
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||
|
||||
|
||||
# test configurations
|
||||
num_tokens_range = [1, 16, 256, 4096]
|
||||
num_experts_range = [16, 64, 224, 256, 280, 512]
|
||||
topk_range = [1, 2, 8]
|
||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_experts", "topk"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "triton"], # "triton"
|
||||
line_names=["VLLM", "Triton"], # "Triton"
|
||||
plot_name="moe-align-block-size-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, num_experts, topk, provider):
|
||||
"""Benchmark function for Triton."""
|
||||
block_size = 256
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
|
||||
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num_experts",
|
||||
type=int,
|
||||
default=64,
|
||||
choices=[8, 16, 32, 64, 128, 256],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=8,
|
||||
choices=[2, 4, 8],
|
||||
help="Top-k value for correctness check.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Running correctness check...")
|
||||
check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
Reference in New Issue
Block a user