Fix w8a8 benchmark and add Llama-3-8B (#5562)

This commit is contained in:
Cody Yu
2024-06-16 23:48:06 -07:00
committed by GitHub
parent 845a3f26f9
commit e2b85cf86a
2 changed files with 19 additions and 8 deletions

View File

@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
# impl
def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch.mm(a, b)
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_i8_impl,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
# cutlass impl
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers = []
# pytorch impl w. bf16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# cutlass impl: bf16 output
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.bfloat16, label, sub_label, cutlass_impl,
"cutlass_fp8_fp8_bf16_scaled_mm"))
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
# cutlass impl: fp16 output
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.float16, label, sub_label, cutlass_impl,
"cutlass_fp8_fp8_fp16_scaled_mm"))
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
return timers