2025-07-18 11:32:22 +08:00
|
|
|
import copy
|
2025-02-25 22:52:41 +08:00
|
|
|
import random
|
2025-07-18 11:32:22 +08:00
|
|
|
import time
|
2025-02-25 22:52:41 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
import deep_gemm
|
2025-07-18 11:32:22 +08:00
|
|
|
from deep_gemm.testing import (
|
|
|
|
|
bench, bench_kineto,
|
|
|
|
|
calc_diff, count_bytes
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from generators import (
|
2025-09-25 16:19:07 +08:00
|
|
|
KernelType, get_arch_major, get_ue8m0_usage,
|
2025-07-18 11:32:22 +08:00
|
|
|
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
|
|
|
|
|
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
|
|
|
|
|
)
|
2025-05-14 14:47:58 +08:00
|
|
|
|
|
|
|
|
|
2025-02-25 22:52:41 +08:00
|
|
|
def test_gemm() -> None:
    """Check the dense FP8 GEMM entry points against a reference, then benchmark.

    For every configuration yielded by ``enumerate_normal`` the kernel is run
    twice on freshly generated inputs: once through ``fp8_gemm_nt`` and once
    through the entry point named after the operand layout
    (``fp8_gemm_<layout>``). Each result is compared with ``ref_d`` using
    ``calc_diff``. ``fp8_gemm_nt`` is then timed with ``bench_kineto`` and
    reported next to the time measured for ``cublaslt_gemm_nt``.

    Raises:
        AssertionError: if ``calc_diff`` is not below 0.001, or if the first
            element of ``a`` / ``b`` is not contiguous in the alias pass.
    """
    print('Testing GEMM:')

    configs = enumerate_normal(torch.float8_e4m3fn)
    for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in configs:
        a_is_k_major = major_a.is_k_major()
        b_is_k_major = major_b.is_k_major()
        is_1d1d = kernel_type.is_1d1d()

        # Labels used in the layout-specific function name and in the report.
        major_opt = ('N' if a_is_k_major else 'T') + ('T' if b_is_k_major else 'N')
        out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
        acc_opt = f'acc={int(accumulate)}'
        kernel_opt = '1D1D' if is_1d1d else '1D2D'

        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0
        # An explicit recipe is only passed for accumulating 1D1D kernels.
        recipe = (1, 1, 128) if is_1d1d and accumulate else None

        # Correctness: default NT entry point first, then the layout alias.
        for test_alias in (False, True):
            a, b, c, d, ref_d = generate_normal(
                m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type,
                use_ue8m0=use_ue8m0)

            if test_alias:
                layout_suffix = major_opt.lower()
                # Operands that are not K-major are handed over transposed so
                # they match the layout-named entry point.
                # NOTE(review): a/b look like (data, scale) pairs — confirm.
                if not a_is_k_major:
                    a = (a[0].T, a[1].T)
                if not b_is_k_major:
                    b = (b[0].T, b[1].T)
                assert a[0].is_contiguous() and b[0].is_contiguous()
            else:
                layout_suffix = 'nt'

            gemm = getattr(deep_gemm, f'fp8_gemm_{layout_suffix}')
            gemm(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)

            diff = calc_diff(d, ref_d)
            assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
                                  f'{diff:.5f}, alias={test_alias}')

        # Performance: regenerate inputs so the timed run starts from fresh data.
        a, b, c, d, ref_d = generate_normal(
            m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type,
            use_ue8m0=use_ue8m0)

        # noinspection PyShadowingNames
        def run_fp8():
            return deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)

        # noinspection PyShadowingNames
        def run_cublaslt():
            return deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c)

        t = bench_kineto(run_fp8, 'fp8_gemm', suppress_kineto_output=True)
        # Two names are requested, so two timings come back; they are summed
        # for the comparison below.
        cublas_t, split_k_t = bench_kineto(run_cublaslt, ('nvjet', 'reduce'), suppress_kineto_output=True)

        tflops = 2 * m * n * k / t / 1e12
        # `c` only contributes to the traffic estimate when accumulating.
        gb_per_s = (count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t
        vs_cublas = (cublas_t + split_k_t) / t

        print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
              f'{t * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | '
              f'{gb_per_s:4.0f} GB/s | '
              f'{vs_cublas:.2f}x cuBLAS')
    print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_m_grouped_gemm_contiguous() -> None:
    """Check the m-grouped contiguous FP8 GEMM against a reference, then benchmark.

    For every configuration yielded by ``enumerate_m_grouped_contiguous`` the
    kernel is run through ``m_grouped_fp8_gemm_nt_contiguous`` and through the
    entry point named after the operand layout. Rows whose ``m_indices`` entry
    is ``-1`` are zeroed in the output before it is compared with ``ref_d``.
    The NT entry point is then timed with ``bench_kineto``.

    Raises:
        AssertionError: if ``calc_diff`` is not below 0.001, if ``major_a`` is
            not K-major in the alias pass, or if the first element of
            ``a`` / ``b`` is not contiguous in the alias pass.
    """
    print('Testing m-grouped contiguous GEMM:')

    configs = enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn)
    for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in configs:
        a_is_k_major = major_a.is_k_major()
        b_is_k_major = major_b.is_k_major()

        major_opt = ('N' if a_is_k_major else 'T') + ('T' if b_is_k_major else 'N')
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0

        # Correctness: default NT entry point first, then the layout alias.
        for test_alias in (False, True):
            m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(
                num_groups, expected_m_per_group, n, k, major_a, major_b,
                use_ue8m0=use_ue8m0)

            if test_alias:
                layout_suffix = major_opt.lower()
                assert a_is_k_major
                # Only `b` may need to be handed over transposed here.
                if not b_is_k_major:
                    b = (b[0].mT, b[1].mT)
                assert a[0].is_contiguous() and b[0].is_contiguous()
            else:
                layout_suffix = 'nt'

            grouped_gemm = getattr(deep_gemm, f'm_grouped_fp8_gemm_{layout_suffix}_contiguous')
            grouped_gemm(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)

            # Rows tagged with -1 are zeroed so they do not affect the comparison.
            skipped_rows = (m_indices == -1).unsqueeze(1)
            d = torch.where(skipped_rows, torch.zeros_like(d), d)

            diff = calc_diff(d, ref_d)
            assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'

        # Performance: regenerate inputs so the timed run starts from fresh data.
        m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(
            num_groups, expected_m_per_group, n, k, major_a, major_b,
            use_ue8m0=use_ue8m0)

        # noinspection PyShadowingNames
        def run_kernel():
            deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)

        t = bench_kineto(run_kernel, 'fp8_gemm', suppress_kineto_output=True)

        tflops = 2 * m * n * k / t / 1e12
        gb_per_s = count_bytes(a, b, d) / 1e9 / t
        print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): '
              f'{t * 1e6:4.0f} us | '
              f'{tflops:4.0f} TFLOPS | '
              f'{gb_per_s:4.0f} GB/s')
    print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_m_grouped_gemm_masked() -> None:
    """Check the m-grouped masked FP8 GEMM against a reference, then benchmark.

    For every configuration yielded by ``enumerate_m_grouped_masked`` the
    kernel is run on ten freshly generated inputs. For each group only the
    first ``masked_m[j]`` rows of the output are compared with ``ref_d``.
    The kernel is then timed with ``bench_kineto``; the reported FLOPS use the
    sum of ``masked_m`` as the row count.

    Raises:
        AssertionError: if ``calc_diff`` for any group is not below 0.001.
    """
    print('Testing m-grouped masked GEMM:')

    # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
    configs = enumerate_m_grouped_masked(torch.float8_e4m3fn)
    for kernel_type, num_groups, max_m, expected_m_per_group, n, k in configs:
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0

        # Test correctness
        for _ in range(10):
            a, b, masked_m, d, ref_d = generate_m_grouped_masked(
                num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
            deep_gemm.m_grouped_fp8_gemm_nt_masked(
                a, b, d, masked_m, expected_m_per_group,
                disable_ue8m0_cast=disable_ue8m0_cast)

            for j in range(num_groups):
                # Rows past masked_m[j] are excluded from the comparison.
                rows = masked_m[j].item()
                diff = calc_diff(d[j, :rows], ref_d[j, :rows])
                assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'

        # Construct full cases
        a, b, masked_m, d, ref_d = generate_m_grouped_masked(
            num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)

        # noinspection PyShadowingNames
        def run_kernel():
            deep_gemm.m_grouped_fp8_gemm_nt_masked(
                a, b, d, masked_m, expected_m_per_group,
                disable_ue8m0_cast=disable_ue8m0_cast)

        # Test performance with fixed shapes
        valid_m = masked_m.sum().item()
        t = bench_kineto(run_kernel, 'fp8_gemm', suppress_kineto_output=True)

        tflops = 2 * valid_m * n * k / t / 1e12
        # Bytes of `a` and `d` are scaled by the fraction of rows counted in
        # `valid_m`; bytes of `b` are counted in full.
        gb_per_s = (count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t
        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
              f'{t * 1e6:4.0f} us | '
              f'{tflops:4.0f} TFLOPS | '
              f'{gb_per_s:4.0f} GB/s')
    print()
|
|
|
|
|
|
|
|
|
|
|
2025-07-18 11:32:22 +08:00
|
|
|
def test_k_grouped_gemm_contiguous() -> None:
    """Check the k-grouped contiguous FP8 GEMM against a reference, then benchmark.

    The entry point is chosen once from ``get_arch_major()``: the ``nt``
    variant when it returns 9, the ``tn`` variant otherwise. For every
    configuration yielded by ``enumerate_k_grouped_contiguous`` the kernel is
    run twice: once with the group sizes as given, and once with one randomly
    chosen group size set to 0 (only when there is more than one group). Each
    result is compared with ``ref_d`` using ``calc_diff``. The kernel is then
    timed with ``bench_kineto`` on the unmodified group sizes.

    Raises:
        AssertionError: if ``calc_diff`` is not below 0.001.
    """
    print('Testing k-grouped contiguous GEMM:')

    if get_arch_major() == 9:
        k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous
    else:
        k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_tn_contiguous

    for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
        use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)

        # Correctness: as-is first, then with one group emptied.
        for test_empty_groups in (False, True):
            # Work on a copy so the benchmark below still sees the original sizes.
            new_ks = copy.deepcopy(ks)
            if test_empty_groups and len(ks) > 1:
                # NOTE(review): indexes by num_groups while the guard uses
                # len(ks) — assumes the two are equal; confirm in generators.
                new_ks[random.randint(0, num_groups - 1)] = 0

            k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
                num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
            new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
            k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)

            diff = calc_diff(d, ref_d)
            # Report the group sizes that were actually run (`new_ks`), not the
            # original `ks`, so a failure in the empty-group pass is identifiable.
            assert diff < 0.001, (f'{m=}, {n=}, {k=}, ks={new_ks}, '
                                  f'empty_groups={test_empty_groups}, {diff:.5f}')

        # Test performance
        k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
            num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
        ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')

        # noinspection PyShadowingNames
        def test_func():
            k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)

        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
        print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
              f'{t * 1e6:4.0f} us | '
              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
    print()
|
|
|
|
|
|
|
|
|
|
|
2025-02-25 22:52:41 +08:00
|
|
|
if __name__ == '__main__':
    # Seed Python's and PyTorch's RNGs with 0 before any test runs.
    random.seed(0)
    torch.manual_seed(0)

    # Turn on the TF32 switches for CUDA matmul and cuDNN.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Show which deep_gemm installation is under test.
    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # Run every suite in the same fixed order.
    for run_suite in (
        test_gemm,
        test_m_grouped_gemm_contiguous,
        test_m_grouped_gemm_masked,
        test_k_grouped_gemm_contiguous,
    ):
        run_suite()
|