# Files
# DeepGEMM/tests/test_attention.py
# 2025-09-25 16:19:07 +08:00
#
# 65 lines
# 2.5 KiB
# Python

import random
import torch
from typing import Tuple
import deep_gemm
from deep_gemm.testing import bench_kineto, calc_diff, count_bytes
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB
def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]):
    """Insert a zero-filled middle segment into every head of a 2D tensor.

    ``d`` has shape ``(m, n)``. Each consecutive group of ``left + right``
    columns is treated as one head and widened to ``left + mid + right``
    columns by placing ``mid`` zeros after its first ``left`` columns.

    Args:
        d: 2D tensor whose second dimension is a multiple of ``left + right``.
        head_splits: ``(left, mid, right)`` column widths within one head.

    Returns:
        A new tensor of shape ``(m, num_heads * (left + mid + right))`` with
        the dtype and device of ``d``.

    Raises:
        AssertionError: if ``n`` is not a multiple of ``left + right``.
    """
    left, mid, right = head_splits
    m, n = d.shape
    assert n % (left + right) == 0
    num_heads = n // (left + right)

    # View every row as `num_heads` chunks of width `left + right`
    d = d.view(m, num_heads, -1)
    d_left = d[:, :, :left]
    # Slice from `left` instead of `-right:`: with `right == 0` the negative
    # form becomes `-0:` and would select the whole head instead of nothing
    d_right = d[:, :, left:]
    d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device)
    return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1)
def test_gemm_skip_head_mid() -> None:
    """Check and benchmark ``deep_gemm.fp8_gemm_nt_skip_head_mid``.

    For every kernel type returned by ``get_kernel_types`` for
    ``torch.float8_e4m3fn`` and a fixed set of ``(m, n, k)`` shapes, the kernel
    output is compared with the reference produced by ``generate_normal``
    (both widened with ``apply_skip_head_mid``). The kernel is then timed with
    ``bench_kineto`` and one line of performance figures is printed per case.

    Raises:
        AssertionError: if ``calc_diff`` of output and reference is not
            below 0.001.
    """
    print('Testing GEMM skip head mid:')
    head_splits = (128, 64, 128)

    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
    out_dtype, accumulate = torch.bfloat16, False

    for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn):
        # These depend only on the kernel type, so compute them once per type
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0

        for m in (128, 4096):
            for n, k in [(32768, 512), (8192, 512)]:
                a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
                # Widen both the output buffer and the reference with the zero-filled mid segment
                d = apply_skip_head_mid(d, head_splits)
                ref_d = apply_skip_head_mid(ref_d, head_splits)

                # Correctness: the kernel writes into `d`, which is then compared with the reference
                deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast)
                diff = calc_diff(d, ref_d)
                assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}'

                # Performance: time the same call and report latency and throughput
                t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast),
                                 'fp8_gemm', suppress_kineto_output=True)
                print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): '
                      f'{t * 1e6:4.0f} us | '
                      f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
                      f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s')
    print()
if __name__ == '__main__':
    # Fix the seeds of Python's and PyTorch's random number generators
    random.seed(0)
    torch.manual_seed(0)

    # Turn on the TF32 switch for both CUDA matmul and cuDNN
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True

    test_gemm_skip_head_mid()