Convert benchmarks to ruff format (#18068)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul)
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
@@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
def bench_fn(
|
||||
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
@@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
|
||||
|
||||
def bench_int8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
"""Benchmark INT8-based kernels."""
|
||||
assert dtype == torch.int8
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"cutlass_i8_i8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, None, bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp, bias),
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, bias
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
|
||||
),
|
||||
}
|
||||
|
||||
timers = []
|
||||
@@ -96,73 +101,65 @@ def bench_int8(
|
||||
|
||||
|
||||
def bench_fp8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
"""Benchmark FP8-based kernels."""
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
a_cont = a.contiguous()
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
block_scale_a = torch.rand((m, k // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_b = torch.rand((k // 128, n // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
|
||||
block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
print(m, k, n)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
|
||||
block_scale_b.t(), (128, 128)),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
|
||||
block_scale_b_K_major, torch.float16),
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
||||
),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16
|
||||
),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
|
||||
),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
|
||||
),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
|
||||
),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.float16
|
||||
),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, bias
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
|
||||
),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
|
||||
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
|
||||
),
|
||||
}
|
||||
|
||||
timers = []
|
||||
@@ -175,13 +172,15 @@ def bench_fp8(
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
def bench(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
@@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
def run(
|
||||
dtype: torch.dtype,
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
bench_kernels=bench_kernels)
|
||||
timers = bench(
|
||||
dtype,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
bench_kernels=bench_kernels,
|
||||
)
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
return results
|
||||
|
||||
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
def make_output(
|
||||
data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
):
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
@@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement],
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
@@ -285,7 +289,7 @@ def run_model_bench(args):
|
||||
pkl.dump(all_data, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
@@ -310,19 +314,21 @@ Benchmark Cutlass GEMM.
|
||||
Output:
|
||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kernels",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Exact names of the kernels to benchmark. If not set, runs all kernels."
|
||||
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
@@ -343,19 +349,19 @@ Benchmark Cutlass GEMM.
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
model_parser.add_argument("--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys())
|
||||
model_parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
model_parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
model_parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user