Make various updates and fixes (#198)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import random
|
||||
import torch
|
||||
from typing import Generator, Tuple, List
|
||||
from typing import Generator, List
|
||||
|
||||
from deep_gemm.utils import (
|
||||
align, ceil_div,
|
||||
@@ -11,7 +11,6 @@ from deep_gemm.utils import (
|
||||
|
||||
|
||||
class KernelType(enum.Enum):
|
||||
# For SM100 GEMMs
|
||||
Kernel1D1D = 0
|
||||
Kernel1D2D = 1
|
||||
KernelNoSF = 2
|
||||
@@ -48,62 +47,87 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool:
|
||||
return kernel_type.is_1d1d()
|
||||
|
||||
|
||||
def get_kernel_types(use_bf16: bool = False) -> tuple:
|
||||
if use_bf16:
|
||||
def get_kernel_types(dtype: torch.dtype) -> tuple:
|
||||
if dtype == torch.bfloat16:
|
||||
return (KernelType.KernelNoSF, )
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)
|
||||
|
||||
# TODO: SM100 1D2D kernels are going to be deprecated
|
||||
# But if you want to test it, please use:
|
||||
# `(KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)`
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, )
|
||||
|
||||
|
||||
def get_out_dtype() -> tuple:
|
||||
return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float)
|
||||
def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
|
||||
for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
|
||||
for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
|
||||
if major_a.is_mn_major() and not allow_a_mn_major:
|
||||
continue
|
||||
if major_b.is_mn_major() and not allow_b_mn_major:
|
||||
continue
|
||||
yield major_a, major_b
|
||||
|
||||
|
||||
def get_major_ab(freeze_a: bool) -> tuple:
|
||||
# TODO: test other major-ness for SM90 BF16 GEMMs
|
||||
if get_arch_major() == 9:
|
||||
return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), )
|
||||
if freeze_a:
|
||||
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor)
|
||||
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \
|
||||
(MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
def enumerate_normal(dtype: torch.dtype) -> Generator:
|
||||
assert dtype in (torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
fp32_output_nk = [(256, 7168), (129280, 7168)]
|
||||
bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
|
||||
m_fwd_list, m_bwd_list = [128, 4096], [4096, ]
|
||||
nk_list = bf16_output_nk
|
||||
|
||||
# Only BF16 GEMM needs FP32 outputs
|
||||
if dtype == torch.bfloat16:
|
||||
nk_list += fp32_output_nk
|
||||
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
# Forward
|
||||
for m in m_fwd_list:
|
||||
for n, k in nk_list:
|
||||
out_dtype = torch.float if (n, k) in fp32_output_nk else torch.bfloat16
|
||||
yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
|
||||
|
||||
# TODO: support BF16 SM90 MN-major kernels
|
||||
if dtype == torch.bfloat16 and get_arch_major() == 9:
|
||||
continue
|
||||
|
||||
# Backward
|
||||
for m in m_bwd_list:
|
||||
for n, k in nk_list:
|
||||
override_major = MajorTypeAB.MNMajor
|
||||
override_kernel_type = kernel_type
|
||||
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
|
||||
override_major = MajorTypeAB.KMajor
|
||||
override_kernel_type = KernelType.Kernel1D1D
|
||||
yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
|
||||
|
||||
|
||||
def enumerate_normal(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
|
||||
for major_a, major_b in get_major_ab(False):
|
||||
for out_dtype in get_out_dtype():
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
|
||||
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
|
||||
|
||||
|
||||
def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
|
||||
for major_a, major_b in get_major_ab(True):
|
||||
for major_a, major_b in get_major_ab(False, get_arch_major() > 9):
|
||||
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
|
||||
|
||||
|
||||
def enumerate_m_grouped_masked() -> Generator:
|
||||
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
|
||||
max_m = 4096
|
||||
for kernel_type in get_kernel_types():
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
||||
for n, k in ((4096, 7168), (7168, 2048), ):
|
||||
yield kernel_type, num_groups, max_m, m, n, k
|
||||
|
||||
|
||||
def enumerate_k_grouped_contiguous():
|
||||
# TODO: support SM90 kernels
|
||||
if get_arch_major() == 9:
|
||||
return []
|
||||
|
||||
# Only K-major is supported for SM90
|
||||
major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 \
|
||||
else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
# Must with FP32 accumulation and 1D1D kernels
|
||||
for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64
|
||||
( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32
|
||||
(16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16
|
||||
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
|
||||
yield num_groups, m, n, ks, expected_k_per_group
|
||||
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group
|
||||
|
||||
|
||||
def enumerate_sf_layout():
|
||||
@@ -134,6 +158,7 @@ def enumerate_transpose():
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
kernel_type: KernelType,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
@@ -147,7 +172,9 @@ def generate_normal(m: int, n: int, k: int,
|
||||
b = b if major_b.is_k_major() else b.T.contiguous().T
|
||||
return a, b, c, d, ref_d
|
||||
|
||||
a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \
|
||||
else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
|
||||
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
|
||||
return a_fp8, b_fp8, c, d, ref_d
|
||||
@@ -214,7 +241,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
|
||||
return a_fp8, b_fp8, masked_m, d, ref_d
|
||||
|
||||
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool):
|
||||
assert get_mk_alignment_for_contiguous_layout() % 128 == 0
|
||||
k = sum(ks)
|
||||
|
||||
@@ -232,4 +259,20 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int]
|
||||
|
||||
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
|
||||
# Transpose for K Major A/B
|
||||
if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):
|
||||
a, sfa = a_fp8
|
||||
b, sfb = b_fp8
|
||||
new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device)
|
||||
new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device)
|
||||
prefix = 0
|
||||
for K in ks:
|
||||
new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten()
|
||||
new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten()
|
||||
prefix += K
|
||||
a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T)
|
||||
else:
|
||||
assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
|
||||
return k, a_fp8, b_fp8, c, d, ref_d
|
||||
|
||||
64
tests/test_attention.py
Normal file
64
tests/test_attention.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import bench_kineto, calc_diff, count_bytes
|
||||
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
|
||||
|
||||
from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB
|
||||
|
||||
|
||||
def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]):
|
||||
left, mid, right = head_splits
|
||||
m, n = d.shape
|
||||
assert n % (left + right) == 0
|
||||
num_heads = n // (left + right)
|
||||
|
||||
# Split and insert padding tensor
|
||||
d = d.view(m, num_heads, -1)
|
||||
d_left = d[:, :, :left]
|
||||
d_right = d[:, :, -right:]
|
||||
|
||||
d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device)
|
||||
return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1)
|
||||
|
||||
|
||||
def test_gemm_skip_head_mid() -> None:
|
||||
print('Testing GEMM skip head mid:')
|
||||
head_splits = (128, 64, 128)
|
||||
|
||||
major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
|
||||
out_dtype, accumulate = torch.bfloat16, False
|
||||
|
||||
for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(32768, 512), (8192, 512)]:
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
use_ue8m0 = get_ue8m0_usage(kernel_type)
|
||||
disable_ue8m0_cast = not use_ue8m0
|
||||
|
||||
a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
|
||||
d = apply_skip_head_mid(d, head_splits)
|
||||
ref_d = apply_skip_head_mid(ref_d, head_splits)
|
||||
|
||||
deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}'
|
||||
|
||||
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast),
|
||||
'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
test_gemm_skip_head_mid()
|
||||
@@ -7,6 +7,7 @@ from deep_gemm.testing import (
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
from generators import (
|
||||
get_arch_major,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal,
|
||||
generate_m_grouped_contiguous, generate_m_grouped_masked
|
||||
)
|
||||
@@ -14,14 +15,18 @@ from generators import (
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True):
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
|
||||
# TODO: support accumulation for SM90 BF16 GEMM
|
||||
if get_arch_major() == 9 and accumulate:
|
||||
continue
|
||||
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
acc_opt = f'acc={int(accumulate)}'
|
||||
|
||||
for test_alias in (False, True):
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
|
||||
func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}'
|
||||
if test_alias:
|
||||
a = a if major_a.is_k_major() else a.T
|
||||
@@ -31,28 +36,22 @@ def test_gemm() -> None:
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, '
|
||||
f'{diff:.5f}, alias={test_alias}')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
|
||||
|
||||
cublas_t = 0
|
||||
t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True)
|
||||
if accumulate == 0 and out_dtype == torch.bfloat16:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True)
|
||||
except Exception:
|
||||
pass
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:5.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
|
||||
f'{cublas_t / t:.2f}x cuBLAS')
|
||||
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing m-grouped contiguous GEMM:')
|
||||
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True):
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
|
||||
@@ -85,7 +84,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing m-grouped masked GEMM:')
|
||||
|
||||
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
|
||||
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
|
||||
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16):
|
||||
# Test correctness
|
||||
for i in range(10):
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
|
||||
@@ -111,6 +110,27 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print()
|
||||
|
||||
|
||||
def test_cublaslt_gemm() -> None:
|
||||
print('Testing cuBLASLt GEMM:')
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
acc_opt = f'acc={int(accumulate)}'
|
||||
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
|
||||
deep_gemm.cublaslt_gemm_nt(a, b, d, c=c)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 5e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})'
|
||||
|
||||
t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,)
|
||||
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:5.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
@@ -121,5 +141,9 @@ if __name__ == '__main__':
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
# TODO: support SM100
|
||||
if get_arch_major() == 9:
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
|
||||
test_cublaslt_gemm()
|
||||
|
||||
85
tests/test_einsum.py
Normal file
85
tests/test_einsum.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import random
|
||||
import torch
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
bench, bench_kineto,
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
|
||||
|
||||
def test_bmk_bnk_mn() -> None:
|
||||
print('Testing "bmk, bnk -> mn":')
|
||||
for s in (129, 4096, 8192):
|
||||
for m, n, k in [(128, 384, 128), (256, 256, 256), (384, 128, 384)]:
|
||||
for dtype in (torch.float, torch.bfloat16):
|
||||
a = torch.randn((s, m, k), dtype=torch.bfloat16, device='cuda')
|
||||
b = torch.randn((s, n, k), dtype=torch.bfloat16, device='cuda')
|
||||
d = torch.randn((m, n), dtype=dtype, device='cuda')
|
||||
c = d if dtype == torch.float else None
|
||||
|
||||
# Test correctness
|
||||
ref_d = (c if dtype == torch.float else 0) + torch.bmm(a.float(), b.float().mT).sum(0)
|
||||
deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c)
|
||||
assert calc_diff(d, ref_d) < 1e-5
|
||||
|
||||
t = bench_kineto(lambda: deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c), 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True)
|
||||
print(f' > Perf (b={s:4.0f}, {m=}, {n=}, {k=}, {"FP32" if dtype == torch.float else "BF16"}): ',
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * s * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b) + (d.numel() * 4)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_bhr_hdr_bhd():
|
||||
print('Testing "bhr, hdr -> bhd":')
|
||||
for b in (128, 4096, 8192):
|
||||
for h, r, d in [(128, 512, 128)]:
|
||||
x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16)
|
||||
fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
|
||||
y = fy[:, :, :r]
|
||||
ref_z = torch.einsum('bhr,hdr->bhd', x, y)
|
||||
z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.einsum('bhr,hdr->bhd', x, y, z)
|
||||
assert calc_diff(z, ref_z) < 1e-10
|
||||
|
||||
t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'nvjet', suppress_kineto_output=True)
|
||||
print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | '
|
||||
f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_bhd_hdr_bhr():
|
||||
print('Testing "bhd, hdr -> bhr":')
|
||||
for b in (128, 4096, 8192):
|
||||
for h, r, d in [(128, 512, 128)]:
|
||||
x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16)
|
||||
fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
|
||||
y = fy[:, :, :r]
|
||||
ref_z = torch.einsum('bhd,hdr->bhr', x, y)
|
||||
z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.einsum('bhd,hdr->bhr', x, y, z)
|
||||
assert calc_diff(z, ref_z) < 1e-10
|
||||
|
||||
t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'nvjet', suppress_kineto_output=True)
|
||||
print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | '
|
||||
f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print('Library path:')
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_bmk_bnk_mn()
|
||||
test_bhr_hdr_bhd()
|
||||
test_bhd_hdr_bhr()
|
||||
@@ -10,7 +10,7 @@ from deep_gemm.testing import (
|
||||
)
|
||||
|
||||
from generators import (
|
||||
KernelType, get_ue8m0_usage,
|
||||
KernelType, get_arch_major, get_ue8m0_usage,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
|
||||
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
|
||||
)
|
||||
@@ -18,7 +18,7 @@ from generators import (
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal():
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
@@ -26,42 +26,35 @@ def test_gemm() -> None:
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
use_ue8m0 = get_ue8m0_usage(kernel_type)
|
||||
disable_ue8m0_cast = not use_ue8m0
|
||||
recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None
|
||||
|
||||
for test_alias in (False, True):
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
|
||||
func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
|
||||
if test_alias:
|
||||
a = a if major_a.is_k_major() else (a[0].T, a[1].T)
|
||||
b = b if major_b.is_k_major() else (b[0].T, b[1].T)
|
||||
assert a[0].is_contiguous() and b[0].is_contiguous()
|
||||
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
|
||||
f'{diff:.5f}, alias={test_alias}')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
|
||||
|
||||
# Test launch overhead
|
||||
launch_start_t = time.time_ns()
|
||||
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
launch_end_t = time.time_ns()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
|
||||
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe),
|
||||
'fp8_gemm', suppress_kineto_output=True)
|
||||
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:4.0f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
|
||||
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing m-grouped contiguous GEMM:')
|
||||
|
||||
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous():
|
||||
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
@@ -86,7 +79,7 @@ def test_m_grouped_gemm_contiguous() -> None:
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): '
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
|
||||
@@ -97,7 +90,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing m-grouped masked GEMM:')
|
||||
|
||||
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
|
||||
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
|
||||
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
use_ue8m0 = get_ue8m0_usage(kernel_type)
|
||||
disable_ue8m0_cast = not use_ue8m0
|
||||
@@ -130,26 +123,31 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
def test_k_grouped_gemm_contiguous() -> None:
|
||||
print('Testing k-grouped contiguous GEMM:')
|
||||
|
||||
for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
|
||||
k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
|
||||
else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
|
||||
for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
|
||||
use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
|
||||
|
||||
for test_empty_groups in (False, True):
|
||||
new_ks = copy.deepcopy(ks)
|
||||
if test_empty_groups:
|
||||
if test_empty_groups and len(ks) > 1:
|
||||
new_ks[random.randint(0, num_groups - 1)] = 0
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0)
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
|
||||
new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
|
||||
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}'
|
||||
k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)
|
||||
|
||||
do_check = True
|
||||
if do_check:
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}'
|
||||
|
||||
# Test performance
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0)
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
|
||||
ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c)
|
||||
k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
|
||||
|
||||
Reference in New Issue
Block a user