[Perf] Improve default triton fused moe configs (#34846)
Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
278
benchmarks/kernels/benchmark_moe_defaults.py
Normal file
278
benchmarks/kernels/benchmark_moe_defaults.py
Normal file
@@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark comparing old vs new default fused MoE configs.
|
||||
|
||||
Runs the triton fused_moe kernel with three configurations for each scenario:
|
||||
1. Tuned config (from JSON file, if available) — the target to match
|
||||
2. Old default (the hardcoded defaults before this change)
|
||||
3. New default (the improved defaults)
|
||||
|
||||
Usage:
|
||||
python benchmarks/kernels/benchmark_moe_defaults.py
|
||||
|
||||
Produces a table showing kernel time (us) and speedup of new vs old defaults.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts,
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def old_default_config(M, E, N, K, topk, dtype=None, block_shape=None):
|
||||
"""The original defaults before https://github.com/vllm-project/vllm/pull/34846,
|
||||
for comparison."""
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
return {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"SPLIT_K": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3 if not current_platform.is_rocm() else 2,
|
||||
}
|
||||
elif M <= E:
|
||||
return {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config,
|
||||
M,
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8=False,
|
||||
block_shape=None,
|
||||
num_iters=100,
|
||||
):
|
||||
"""Time a single kernel config. Returns kernel time in microseconds."""
|
||||
init_dtype = torch.float16 if use_fp8 else dtype
|
||||
|
||||
a = torch.randn(M, K, device="cuda", dtype=init_dtype) / 10
|
||||
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=init_dtype) / 10
|
||||
w2 = torch.randn(E, K, N, device="cuda", dtype=init_dtype) / 10
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_fp8:
|
||||
if block_shape is not None:
|
||||
bsn, bsk = block_shape
|
||||
n_tiles_w1 = triton.cdiv(2 * N, bsn)
|
||||
k_tiles_w1 = triton.cdiv(K, bsk)
|
||||
n_tiles_w2 = triton.cdiv(K, bsn)
|
||||
k_tiles_w2 = triton.cdiv(N, bsk)
|
||||
w1_scale = torch.rand(
|
||||
E, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.rand(
|
||||
E, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
w1_scale = torch.rand(E, device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.rand(E, device="cuda", dtype=torch.float32)
|
||||
a1_scale = torch.rand(1, device="cuda", dtype=torch.float32)
|
||||
a2_scale = torch.rand(1, device="cuda", dtype=torch.float32)
|
||||
# Only weights are stored in fp8; activations stay in bf16/fp16
|
||||
# and get dynamically quantized inside the kernel.
|
||||
w1 = w1.to(FP8_DTYPE)
|
||||
w2 = w2.to(FP8_DTYPE)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype=torch.float8_e4m3fn if use_fp8 else None,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
gating = torch.randn(M, E, device="cuda", dtype=torch.float32)
|
||||
|
||||
# Warmup
|
||||
for _ in range(20):
|
||||
with override_config(config):
|
||||
topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True)
|
||||
fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
for _ in range(num_iters):
|
||||
with override_config(config):
|
||||
topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True)
|
||||
fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
return start.elapsed_time(end) / num_iters * 1000 # ms -> us
|
||||
|
||||
|
||||
# Model configurations: (name, E, N, K, topk, dtype_str, use_fp8, block_shape)
|
||||
# N = moe_intermediate_size // tp_size (the value used in config file lookup)
|
||||
MODELS = [
|
||||
# --- Few experts ---
|
||||
("Mixtral bf16", 8, 7168, 4096, 2, None, False, None),
|
||||
("Mixtral fp8", 8, 7168, 4096, 2, "fp8_w8a8", True, None),
|
||||
# --- Many experts: real model shapes at tp=1 ---
|
||||
# Qwen2-MoE-57B: E=60, topk=4, N=1408, K=2048
|
||||
("Qwen2-MoE bf16", 60, 1408, 2048, 4, None, False, None),
|
||||
# DeepSeek-V2: E=64, topk=6, N=1407, K=4096
|
||||
# (use 1408 to avoid odd alignment; real model is 1407)
|
||||
("DeepSeek-V2 bf16", 64, 1408, 4096, 6, None, False, None),
|
||||
# OLMoE-7B: E=64, topk=8, N=2048, K=2048
|
||||
("OLMoE bf16", 64, 2048, 2048, 8, None, False, None),
|
||||
# GLM-4-100B-A10B: E=128, topk=8, N=1408, K=4096
|
||||
("GLM-4-MoE bf16", 128, 1408, 4096, 8, None, False, None),
|
||||
# Qwen3-30B-A3B: E=128, topk=8, N=768, K=2048
|
||||
("Qwen3-MoE bf16", 128, 768, 2048, 8, None, False, None),
|
||||
# DeepSeek-V3 / MiMo-V2-Flash: E=256, topk=8, N=2048, K=7168
|
||||
("DeepSeek-V3 bf16", 256, 2048, 7168, 8, None, False, None),
|
||||
# Qwen3.5-70B-A22B (Qwen3-Next): E=512, topk=10, N=512, K=2048
|
||||
("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None),
|
||||
# E=128 N=1856 bf16
|
||||
("E128 N1856 bf16", 128, 1856, 4096, 8, None, False, None),
|
||||
# E=256 N=512 bf16 (DS-V3 tp=4)
|
||||
("DS-V3 tp4 bf16", 256, 512, 7168, 8, None, False, None),
|
||||
# E=512 N=512 bf16 (Qwen3-Next tp=1)
|
||||
("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None),
|
||||
# E=512 N=256 bf16 (Qwen3-Next tp=2)
|
||||
("Qwen3-Next tp2", 512, 256, 2048, 10, None, False, None),
|
||||
# --- FP8 block quant (many experts) ---
|
||||
# DS-V3 tp=4: E=256, N=512, fp8 block
|
||||
("DS-V3 tp4 fp8blk", 256, 512, 7168, 8, "fp8_w8a8", True, [128, 128]),
|
||||
# DS-V3 tp=8: E=256, N=256, fp8 block
|
||||
("DS-V3 tp8 fp8blk", 256, 256, 7168, 8, "fp8_w8a8", True, [128, 128]),
|
||||
# Qwen3-Next tp=2 fp8 block
|
||||
("Qwen3-Next tp2 fp8blk", 512, 256, 2048, 10, "fp8_w8a8", True, [128, 128]),
|
||||
]
|
||||
|
||||
BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
|
||||
def main():
|
||||
set_random_seed(0)
|
||||
torch.set_default_device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
for name, E, N, K, topk, dtype_str, use_fp8, block_shape in MODELS:
|
||||
print(f"\n{'=' * 90}")
|
||||
print(f" {name} (E={E}, N={N}, K={K}, topk={topk})")
|
||||
print(f"{'=' * 90}")
|
||||
|
||||
# Try to load tuned config
|
||||
block_n = block_shape[0] if block_shape else None
|
||||
block_k = block_shape[1] if block_shape else None
|
||||
tuned = get_moe_configs(E, N, dtype_str, block_n, block_k)
|
||||
has_tuned = tuned is not None
|
||||
print(f" Tuned config available: {has_tuned}")
|
||||
|
||||
hdr = (
|
||||
f"{'Batch':>6} | {'Tuned (us)':>11} | {'Old (us)':>11} | "
|
||||
f"{'New (us)':>11} | {'New/Old':>8} | {'New/Tuned':>10}"
|
||||
)
|
||||
print(f" {hdr}")
|
||||
print(f" {'-' * len(hdr)}")
|
||||
|
||||
for M in BATCH_SIZES:
|
||||
old_cfg = old_default_config(M, E, N, K, topk, dtype_str, block_shape)
|
||||
new_cfg = get_default_config(M, E, N, K, topk, dtype_str, block_shape)
|
||||
|
||||
if has_tuned:
|
||||
tuned_cfg = tuned[min(tuned.keys(), key=lambda x: abs(x - M))]
|
||||
t_tuned = benchmark_config(
|
||||
tuned_cfg,
|
||||
M,
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8=use_fp8,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
t_tuned = None
|
||||
|
||||
t_old = benchmark_config(
|
||||
old_cfg,
|
||||
M,
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8=use_fp8,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
t_new = benchmark_config(
|
||||
new_cfg,
|
||||
M,
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8=use_fp8,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
ratio_new_old = t_new / t_old
|
||||
tuned_str = f"{t_tuned:11.2f}" if t_tuned else f"{'N/A':>11}"
|
||||
ratio_tuned = f"{t_new / t_tuned:10.2f}x" if t_tuned else f"{'N/A':>10}"
|
||||
# flag regressions where new default is >5% slower than old
|
||||
marker = " <--" if ratio_new_old > 1.05 else ""
|
||||
|
||||
print(
|
||||
f" {M:>6} | {tuned_str} | {t_old:11.2f} | {t_new:11.2f} "
|
||||
f"| {ratio_new_old:7.2f}x | {ratio_tuned}{marker}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1233,28 +1233,31 @@ def get_default_config(
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if vllm_is_batch_invariant():
|
||||
config = {
|
||||
return {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
# num_stages can cause triton.runtime.errors.OutOfResources on ROCm.
|
||||
num_stages_rocm = 2
|
||||
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
# num_stages=3 can cause triton.runtime.errors.OutOfResources
|
||||
# on ROCm, set it to 2 instead.
|
||||
# Block-wise quant: tile sizes are constrained by block_shape.
|
||||
# Use a small M tile for decode-like batches where tokens are
|
||||
# spread thin across experts. Larger batches benefit from
|
||||
# GROUP_SIZE_M > 1 because the per-block scales add memory
|
||||
# traffic that benefits from L2 tile reuse.
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_M": 16 if M <= 64 else 64,
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"GROUP_SIZE_M": 1 if M <= 16 else 32,
|
||||
"SPLIT_K": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3 if not current_platform.is_rocm() else 2,
|
||||
"num_stages": 3 if not current_platform.is_rocm() else num_stages_rocm,
|
||||
}
|
||||
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
|
||||
# moe wna16 kernels
|
||||
@@ -1270,21 +1273,52 @@ def get_default_config(
|
||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
else:
|
||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
elif M <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
else:
|
||||
# General defaults for bf16/fp16 and fp8 per-tensor.
|
||||
# Tile sizes scale with batch: small batches are memory-bound
|
||||
# (favor tall-K tiles), large batches are compute-bound (favor
|
||||
# large M/N tiles with more warps).
|
||||
if M <= 32:
|
||||
block_m = 16
|
||||
elif M <= 96:
|
||||
block_m = 32
|
||||
elif M <= 512:
|
||||
block_m = 64
|
||||
else:
|
||||
block_m = 128
|
||||
|
||||
block_n = 64 if M <= 64 else 128
|
||||
|
||||
# Small batches benefit from longer reduction (larger K tile),
|
||||
# while large batches prefer more output parallelism.
|
||||
# FP8 elements are half-width so larger K tiles are always cheap.
|
||||
block_k = 128 if dtype == "fp8_w8a8" or M <= 64 else 64
|
||||
|
||||
# Grouping adjacent M-blocks lets them share weight tiles in L2.
|
||||
# Only helps when there are enough M-blocks per expert to group;
|
||||
# with many experts each one sees few tokens so grouping is useless.
|
||||
tokens_per_expert = M // max(E, 1)
|
||||
group_m = 16 if tokens_per_expert > 128 else 1
|
||||
|
||||
# Large batches have enough blocks to saturate the GPU, so we
|
||||
# use more warps per block to increase arithmetic intensity.
|
||||
num_warps = 4 if M <= 128 else 8
|
||||
|
||||
if current_platform.is_rocm():
|
||||
num_stages = num_stages_rocm
|
||||
elif M <= 32:
|
||||
num_stages = 4
|
||||
else:
|
||||
num_stages = 3
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_m,
|
||||
"SPLIT_K": 1,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user