Convert benchmarks to ruff format (#18068)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 14:43:29 +01:00
committed by GitHub
parent b922c2ebd2
commit 009d9e7590
41 changed files with 3980 additions and 2938 deletions

View File

@@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul)
w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
def bench_fn(
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1
globals = {
@@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
def bench_int8(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"cutlass_i8_i8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_i8_i8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_i8_i8_bf16_scaled_mm_azp":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, None, bias),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp, bias),
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
),
}
timers = []
@@ -96,73 +101,65 @@ def bench_int8(
def bench_fp8(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128),
device="cuda",
dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128),
device="cuda",
dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"pytorch_fp8_fp8_fp16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.float16,
use_fast_accum=True),
"pytorch_fp8_fp8_bf16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True),
"cutlass_fp8_fp8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_fp8_fp8_fp16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)),
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
block_scale_b.t(), (128, 128)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
block_scale_b_K_major, torch.float16),
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16
),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
),
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
),
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16
),
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
),
}
timers = []
@@ -175,13 +172,15 @@ def bench_fp8(
return timers
def bench(dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
def bench(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn:
@@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print()
def run(dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
def run(
dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype,
m,
k,
n,
f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels)
timers = bench(
dtype,
m,
k,
n,
f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels,
)
print_timers(timers)
results.extend(timers)
return results
def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None):
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
@@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}")
@@ -285,7 +289,7 @@ def run_model_bench(args):
pkl.dump(all_data, f)
if __name__ == '__main__':
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@@ -310,19 +314,21 @@ Benchmark Cutlass GEMM.
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']")
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']",
)
parser.add_argument(
"--kernels",
nargs="+",
type=str,
default=None,
help=
"Exact names of the kernels to benchmark. If not set, runs all kernels."
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
)
subparsers = parser.add_subparsers(dest="cmd")
@@ -343,19 +349,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()