135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
|
||
|
|
from vllm import _custom_ops as ops
|
||
|
|
from vllm.platforms import current_platform
|
||
|
|
from vllm.transformers_utils.config import get_config
|
||
|
|
from vllm.triton_utils import triton
|
||
|
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||
|
|
|
||
|
|
# Problem shapes accepted by the DeepSeek-V3 (DSV3) specialized router GEMM.
# A model is routed to this kernel only when both its expert count and its
# hidden size appear in these lists.
DSV3_SUPPORTED_NUM_EXPERTS = [
    256,
    384,
]
DSV3_SUPPORTED_HIDDEN_SIZES = [
    7168,
]

# Problem shapes accepted by the gpt-oss specialized router GEMM (with bias).
GPT_OSS_SUPPORTED_NUM_EXPERTS = [
    32,
    128,
]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [
    2880,
]
|
||
|
|
|
||
|
|
|
||
|
|
def get_batch_size_range(max_batch_size):
    """List the power-of-two batch sizes to sweep, up to ``max_batch_size``.

    Sizes start at 1 and double each step. At most 14 sizes are produced
    (1 through 8192), even if ``max_batch_size`` is larger.
    """
    sizes = []
    size = 1
    # Powers of two grow monotonically, so stop at the first one that
    # exceeds the limit (or once the 14-entry cap is reached).
    while size <= max_batch_size and len(sizes) < 14:
        sizes.append(size)
        size *= 2
    return sizes
|
||
|
|
|
||
|
|
|
||
|
|
def get_model_params(config):
    """Return ``(num_experts, hidden_size)`` for a supported MoE model config.

    The model family is chosen from the first entry of ``config.architectures``.
    DeepSeek V2/V3/V3.2 configs read the expert count from
    ``n_routed_experts``, and gpt-oss configs read it from
    ``num_local_experts``. Both families read ``hidden_size``.

    Raises:
        ValueError: if the config has no (or an empty) ``architectures``
            attribute, or its first architecture is not supported.
    """
    # Primary architecture name -> config attribute holding the expert count.
    experts_attr_by_arch = {
        "DeepseekV2ForCausalLM": "n_routed_experts",
        "DeepseekV3ForCausalLM": "n_routed_experts",
        "DeepseekV32ForCausalLM": "n_routed_experts",
        "GptOssForCausalLM": "num_local_experts",
    }

    # A missing/None/empty architectures list is reported as "unsupported"
    # instead of crashing with AttributeError/TypeError/IndexError.
    architectures = getattr(config, "architectures", None)
    arch = architectures[0] if architectures else None
    experts_attr = experts_attr_by_arch.get(arch)
    if experts_attr is None:
        raise ValueError(f"Unsupported architecture: {architectures}")

    num_experts = getattr(config, experts_attr)
    hidden_size = config.hidden_size
    return num_experts, hidden_size
|
||
|
|
|
||
|
|
|
||
|
|
def get_benchmark(model, max_batch_size, trust_remote_code):
    """Build a Triton perf-report benchmark comparing router GEMM kernels.

    The model config, its (num_experts, hidden_size) shape and the kernel
    eligibility flags are computed once here and shared by every data point.
    Previously they were recomputed for every (batch_size, provider) pair.

    Args:
        model: Model name or path forwarded to ``get_config``.
        max_batch_size: Upper bound for the power-of-two batch-size sweep
            (see ``get_batch_size_range``).
        trust_remote_code: Forwarded to ``get_config``.

    Returns:
        The ``benchmark(batch_size, provider)`` function decorated with
        ``triton.testing.perf_report``. Call ``.run(...)`` on it to run the
        sweep.
    """
    # Load the config and derive the GEMM shape once, not per data point.
    config = get_config(model=model, trust_remote_code=trust_remote_code)
    num_experts, hidden_size = get_model_params(config)

    # The specialized kernels are used only on compute capability 90 or the
    # 100 family, and only for the shapes listed in the module constants.
    is_hopper_or_blackwell = current_platform.is_device_capability(
        90
    ) or current_platform.is_device_capability_family(100)
    allow_dsv3_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
        and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
    )
    allow_gpt_oss_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
        and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
    )
    # Only the gpt-oss router GEMM takes a bias, so the PyTorch baseline
    # includes the bias only in that case to stay comparable.
    has_bias = bool(allow_gpt_oss_router_gemm)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        """Return ``(median, low, high)`` TFLOP/s for one batch size/provider."""
        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        if provider == "torch":

            def runner():
                if has_bias:
                    F.linear(mat_a, mat_b, bias)
                else:
                    F.linear(mat_a, mat_b)

        elif provider == "vllm":
            if allow_dsv3_router_gemm:

                def runner():
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)

            elif allow_gpt_oss_router_gemm:

                def runner():
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)

            else:
                # Fail before benchmarking starts rather than from inside
                # the benchmarked callable.
                raise ValueError("Unsupported router gemm")
        else:
            # Without this branch `runner` would be unbound (NameError).
            raise ValueError(f"Unknown provider: {provider}")

        # Latency quantiles in the order (median, 20th pct, 80th pct).
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            # One multiply-add per (row, expert, hidden) element of the GEMM.
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        # Throughput is inversely proportional to latency, so the slower
        # latency quantile gives the lower TFLOP/s bound and vice versa.
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Command-line options: the model whose router shape is benchmarked, the
    # upper bound of the batch-size sweep, and whether remote config code
    # may be trusted when loading the model config.
    cli = FlexibleArgumentParser()
    cli.add_argument("--model", type=str, default="openai/gpt-oss-20b")
    cli.add_argument("--max-batch-size", type=int, default=16)
    cli.add_argument("--trust-remote-code", action="store_true")
    cli_args = cli.parse_args()

    # Build the perf-report benchmark for the chosen model and print the
    # resulting throughput table.
    router_benchmark = get_benchmark(
        model=cli_args.model,
        max_batch_size=cli_args.max_batch_size,
        trust_remote_code=cli_args.trust_remote_code,
    )
    router_benchmark.run(print_data=True)
|