Remove all cases of fmt: on/off (#26253)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 17:18:14 +01:00
committed by GitHub
parent 4e256cadc2
commit 557b2e961d
5 changed files with 216 additions and 156 deletions

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501
import time
@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
)
def benchmark_shape(m: int,
n: int,
k: int,
warmup: int = 100,
repeat: int = 10000,
verbose: bool = False) -> dict:
def benchmark_shape(
m: int,
n: int,
k: int,
warmup: int = 100,
repeat: int = 10000,
verbose: bool = False,
) -> dict:
"""Benchmark all implementations for a specific (m, n, k) shape."""
if verbose:
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
# Create test tensors
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16
torch.cuda.synchronize()
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
# Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
A, block_size[1], column_major_scales=True)
A, block_size[1], column_major_scales=True
)
# === DeepGEMM Implementation ===
def deepgemm_gemm():
fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
(B_deepgemm, B_scale_deepgemm),
C_deepgemm)
fp8_gemm_nt(
(A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
)
return C_deepgemm
# === vLLM Triton Implementation ===
def vllm_triton_gemm():
return w8a8_triton_block_scaled_mm(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
block_size,
output_dtype=torch.bfloat16)
return w8a8_triton_block_scaled_mm(
A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
block_size,
output_dtype=torch.bfloat16,
)
# === vLLM CUTLASS Implementation ===
def vllm_cutlass_gemm():
return ops.cutlass_scaled_mm(A_vllm_cutlass,
B_vllm.T,
scale_a=A_scale_vllm_cutlass,
scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16)
return ops.cutlass_scaled_mm(
A_vllm_cutlass,
B_vllm.T,
scale_a=A_scale_vllm_cutlass,
scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16,
)
# Run correctness check first
if verbose:
@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
print("vLLM Triton vs DeepGEMM difference: "
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
print("vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
print(
"vLLM Triton vs DeepGEMM difference: "
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
)
print(
"vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
)
# Benchmark implementations
implementations = {
"DeepGEMM": deepgemm_gemm,
"vLLM Triton": vllm_triton_gemm,
"vLLM CUTLASS": vllm_cutlass_gemm
"vLLM CUTLASS": vllm_cutlass_gemm,
}
benchmark_results = {
"shape": {
"m": m,
"n": n,
"k": k
},
"implementations": {}
}
benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
for name, func in implementations.items():
# Warmup
@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
"tflops": tflops,
"gb_s": gb_s,
"diff": {
"DeepGEMM":
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
"Reference":
deepgemm_diff if name == "DeepGEMM" else
(vllm_triton_diff
if name == "vLLM Triton" else vllm_cutlass_diff)
}
"DeepGEMM": 0.0
if name == "DeepGEMM"
else calc_diff(func(), C_deepgemm),
"Reference": deepgemm_diff
if name == "DeepGEMM"
else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
},
}
if verbose:
print(
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
)
print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
# Calculate speedups
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
for name, data in benchmark_results["implementations"].items():
if name != "DeepGEMM":
speedup = baseline / data["time_ms"]
benchmark_results["implementations"][name][
"speedup_vs_deepgemm"] = speedup
benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
if verbose:
print(f"DeepGEMM is {1/speedup:.2f}x "
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
print(
f"DeepGEMM is {1 / speedup:.2f}x "
f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
)
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
"time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
"time_ms"]
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
benchmark_results["implementations"]["vLLM CUTLASS"][
"speedup_vs_triton"] = cutlass_vs_triton
benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
cutlass_vs_triton
)
if verbose:
print(
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@@ -183,8 +184,7 @@ def benchmark_shape(m: int,
def format_table_row(values, widths):
"""Format a row with specified column widths."""
return "| " + " | ".join(f"{val:{w}}"
for val, w in zip(values, widths)) + " |"
return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
def print_table(headers, rows, title=None):
@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
for result in all_results:
shape = result["shape"]
impl_data = result["implementations"]["DeepGEMM"]
deepgemm_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
])
deepgemm_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
]
)
print_table(deepgemm_headers,
deepgemm_rows,
title="DeepGEMM Implementation:")
print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
# Print vLLM Triton table
triton_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
]
triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
triton_rows = []
for result in all_results:
shape = result["shape"]
impl_data = result["implementations"]["vLLM Triton"]
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
triton_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
format_speedup(speedup)
])
triton_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(speedup),
]
)
print_table(triton_headers,
triton_rows,
title="vLLM Triton Implementation:")
print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
# Print vLLM CUTLASS table
cutlass_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
"vs Triton"
"m",
"n",
"k",
"Time (μs)",
"TFLOPS",
"GB/s",
"vs DeepGEMM",
"vs Triton",
]
cutlass_rows = []
for result in all_results:
@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
impl_data = result["implementations"]["vLLM CUTLASS"]
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
cutlass_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton)
])
cutlass_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton),
]
)
print_table(cutlass_headers,
cutlass_rows,
title="vLLM CUTLASS Implementation:")
print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
# Calculate and print averages
print("\n===== AVERAGE PERFORMANCE =====")
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
avg_metrics = {
impl: {
"tflops": 0,
"gb_s": 0,
"time_ms": 0
}
for impl in implementations
impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
}
for result in all_results:
@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
avg_rows.append([
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
])
avg_rows.append(
[impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
)
print_table(avg_headers, avg_rows)
@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
avg_speedups = {
"DeepGEMM vs vLLM Triton": 0,
"DeepGEMM vs vLLM CUTLASS": 0,
"vLLM CUTLASS vs vLLM Triton": 0
"vLLM CUTLASS vs vLLM Triton": 0,
}
for result in all_results:
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
"time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
avg_speedups[
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
avg_speedups[
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups[
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
vllm_triton_time / vllm_cutlass_time
)
print("\n===== AVERAGE SPEEDUPS =====")
speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):
for result in all_results:
for impl in implementations:
avg_diff[impl] += result["implementations"][impl]["diff"][
"Reference"]
avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
diff_headers = ["Implementation", "Avg Diff vs Reference"]
diff_rows = []