65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
|
|
import random
|
||
|
|
import torch
|
||
|
|
from typing import Tuple
|
||
|
|
|
||
|
|
import deep_gemm
|
||
|
|
from deep_gemm.testing import bench_kineto, calc_diff, count_bytes
|
||
|
|
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
|
||
|
|
|
||
|
|
from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB
|
||
|
|
|
||
|
|
|
||
|
|
def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]):
    """Insert a zero-filled middle segment into every head of a 2-D tensor.

    The columns of ``d`` are treated as ``num_heads`` consecutive groups of
    ``left + right`` elements. Each group is widened to
    ``left + mid + right`` elements by placing ``mid`` zeros between its first
    ``left`` columns and its remaining ``right`` columns.

    Args:
        d: Tensor of shape ``(m, n)`` that can be viewed as
            ``(m, num_heads, left + right)``.
        head_splits: The ``(left, mid, right)`` widths of one head.

    Returns:
        A new tensor of shape ``(m, num_heads * (left + mid + right))`` with
        the dtype and device of ``d``.

    Raises:
        AssertionError: If ``n`` is not a multiple of ``left + right``.
    """
    left, mid, right = head_splits
    m, n = d.shape
    head_width = left + right
    assert n % head_width == 0
    num_heads = n // head_width

    # Split each head and insert the padding tensor
    heads = d.view(m, num_heads, head_width)
    d_left = heads[:, :, :left]
    # Slice from `left` instead of `-right:`: when right == 0, `-0:` is `0:`
    # and would select the whole head rather than an empty tail.
    d_right = heads[:, :, left:]

    d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device)
    return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1)
|
||
|
|
|
||
|
|
|
||
|
|
def test_gemm_skip_head_mid() -> None:
    """Check and benchmark ``deep_gemm.fp8_gemm_nt_skip_head_mid``.

    For every kernel type returned by ``get_kernel_types`` for
    ``torch.float8_e4m3fn`` and every ``(m, n, k)`` shape listed below, the
    output and reference tensors from ``generate_normal`` are both expanded
    with ``apply_skip_head_mid``, the kernel is run once and compared against
    the reference with ``calc_diff``, and the same call is then timed with
    ``bench_kineto`` and reported on stdout.

    Raises:
        AssertionError: If ``calc_diff`` of output and reference is not
            below 0.001.
    """
    print('Testing GEMM skip head mid:')
    head_splits = (128, 64, 128)

    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
    out_dtype, accumulate = torch.bfloat16, False

    for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn):
        for m in (128, 4096):
            for n, k in [(32768, 512), (8192, 512)]:
                kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
                use_ue8m0 = get_ue8m0_usage(kernel_type)
                disable_ue8m0_cast = not use_ue8m0

                a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
                d = apply_skip_head_mid(d, head_splits)
                ref_d = apply_skip_head_mid(ref_d, head_splits)

                # One callable for both the correctness run and the timed
                # run, so the checked and benchmarked calls cannot diverge.
                # It is only invoked within this iteration, so capturing the
                # loop variables is safe.
                def run_kernel():
                    return deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast)

                run_kernel()
                diff = calc_diff(d, ref_d)
                assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}'

                t = bench_kineto(run_kernel, 'fp8_gemm', suppress_kineto_output=True)
                print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): '
                      f'{t * 1e6:4.0f} us | '
                      f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
                      f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s')
    print()
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Fix the seeds of Python's `random` module and of torch.
    random.seed(0)
    torch.manual_seed(0)

    # Turn on the TF32 switches for cuDNN and for CUDA matmul.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    test_gemm_skip_head_mid()
|