# Files
# DeepGEMM/tests/test_einsum.py
# 2025-09-25 16:19:07 +08:00
#
# 86 lines
# 3.4 KiB
# Python

import random
import torch
import deep_gemm
from deep_gemm.testing import (
bench, bench_kineto,
calc_diff, count_bytes
)
def test_bmk_bnk_mn() -> None:
    """Check and benchmark the batch-reduced einsum ``bmk,bnk->mn``.

    For every batch size, ``(m, n, k)`` shape and output dtype, the output of
    ``deep_gemm.einsum`` is compared against a float32 ``torch.bmm`` reference
    summed over the batch axis. The same call is then timed with
    ``bench_kineto`` and one perf line is printed.

    With a float32 output, the output tensor is also passed as ``c`` and the
    reference adds its initial values; with a bfloat16 output, ``c`` is
    ``None`` and nothing is added.

    Raises:
        AssertionError: if ``calc_diff`` of output vs. reference is not
            below ``1e-5``.
    """
    print('Testing "bmk, bnk -> mn":')
    batch_sizes = (129, 4096, 8192)
    shapes = [(128, 384, 128), (256, 256, 256), (384, 128, 384)]
    out_dtypes = (torch.float, torch.bfloat16)

    for num_batches in batch_sizes:
        # NOTE: m, n, k must keep these names; the f-string below uses `{m=}`.
        for m, n, k in shapes:
            for out_dtype in out_dtypes:
                is_fp32 = out_dtype == torch.float

                # Inputs are always BF16; only the output dtype varies.
                a = torch.randn((num_batches, m, k), dtype=torch.bfloat16, device='cuda')
                b = torch.randn((num_batches, n, k), dtype=torch.bfloat16, device='cuda')
                d = torch.randn((m, n), dtype=out_dtype, device='cuda')
                # In the FP32 case `c` is the very same tensor object as `d`.
                c = d if is_fp32 else None

                def run():
                    return deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c)

                # Build the reference before the first run: it reads the
                # initial contents of `c`, which aliases the output `d`.
                addend = c if is_fp32 else 0
                ref_d = addend + torch.bmm(a.float(), b.float().mT).sum(0)
                run()
                assert calc_diff(d, ref_d) < 1e-5

                # Benchmark the same call, filtered by kernel name.
                t = bench_kineto(run, 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True)
                dtype_tag = 'FP32' if is_fp32 else 'BF16'
                tflops = 2 * num_batches * m * n * k / t / 1e12
                # NOTE(review): D is counted at 4 bytes/element for both
                # output dtypes — confirm this is intended for BF16 outputs.
                gbps = (count_bytes(a, b) + (d.numel() * 4)) / 1e9 / t
                print(f' > Perf (b={num_batches:4.0f}, {m=}, {n=}, {k=}, {dtype_tag}): ',
                      f'{t * 1e6:4.0f} us | '
                      f'{tflops:4.0f} TFLOPS | '
                      f'{gbps:4.0f} GB/s')
    print()
def test_bhr_hdr_bhd():
    """Check and benchmark the einsum ``bhr,hdr->bhd``.

    For each batch size and ``(h, r, d)`` shape, the BF16 output of
    ``deep_gemm.einsum`` is compared against ``torch.einsum`` on the same
    inputs. The call is then timed with ``bench_kineto`` and one perf line is
    printed.

    Raises:
        AssertionError: if ``calc_diff`` of output vs. reference is not
            below ``1e-10``.
    """
    print('Testing "bhr, hdr -> bhd":')
    shapes = [(128, 512, 128)]

    # NOTE: b, h, r, d must keep these names; the f-string below uses `{b=}` etc.
    for b in (128, 4096, 8192):
        for h, r, d in shapes:
            x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16)
            # `y` is a view into a buffer that is 128 elements wider along
            # the last dimension, so it is not a freshly allocated tensor.
            padded_y = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
            y = padded_y[..., :r]

            ref_z = torch.einsum('bhr,hdr->bhd', x, y)
            z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16)

            def run():
                return deep_gemm.einsum('bhr,hdr->bhd', x, y, z)

            # Correctness first, then timing of the exact same call.
            run()
            assert calc_diff(z, ref_z) < 1e-10

            t = bench_kineto(run, 'nvjet', suppress_kineto_output=True)
            tflops = 2 * b * h * r * d / t / 1e12
            gbps = count_bytes((x, y, z)) / t / 1e9
            print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
                  f'{t * 1e6:4.0f} us | '
                  f'{tflops:.0f} TFLOPS | '
                  f'{gbps:.0f} GB/s')
    print()
def test_bhd_hdr_bhr():
    """Check and benchmark the einsum ``bhd,hdr->bhr``.

    For each batch size and ``(h, r, d)`` shape, the BF16 output of
    ``deep_gemm.einsum`` is compared against ``torch.einsum`` on the same
    inputs. The call is then timed with ``bench_kineto`` and one perf line is
    printed.

    Raises:
        AssertionError: if ``calc_diff`` of output vs. reference is not
            below ``1e-10``.
    """
    print('Testing "bhd, hdr -> bhr":')
    shapes = [(128, 512, 128)]

    # NOTE: b, h, r, d must keep these names; the f-string below uses `{b=}` etc.
    for b in (128, 4096, 8192):
        for h, r, d in shapes:
            x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16)
            # `y` is a view into a buffer that is 128 elements wider along
            # the last dimension, so it is not a freshly allocated tensor.
            padded_y = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
            y = padded_y[..., :r]

            ref_z = torch.einsum('bhd,hdr->bhr', x, y)
            z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16)

            def run():
                return deep_gemm.einsum('bhd,hdr->bhr', x, y, z)

            # Correctness first, then timing of the exact same call.
            run()
            assert calc_diff(z, ref_z) < 1e-10

            t = bench_kineto(run, 'nvjet', suppress_kineto_output=True)
            tflops = 2 * b * h * r * d / t / 1e12
            gbps = count_bytes((x, y, z)) / t / 1e9
            print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
                  f'{t * 1e6:4.0f} us | '
                  f'{tflops:.0f} TFLOPS | '
                  f'{gbps:.0f} GB/s')
    print()
if __name__ == '__main__':
    # Script entry point: configure PyTorch, seed the RNGs, report which
    # deep_gemm build is loaded, then run every einsum test in order.

    # Allow TF32 in both the cuDNN and CUDA matmul backends.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Fixed seeds so that the random inputs are reproducible between runs.
    random.seed(0)
    torch.manual_seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    for run_test in (test_bmk_bnk_mn, test_bhr_hdr_bhd, test_bhd_hdr_bhr):
        run_test()