Make various updates and fixes (#198)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import random
|
||||
import torch
|
||||
from typing import Generator, Tuple, List
|
||||
from typing import Generator, List
|
||||
|
||||
from deep_gemm.utils import (
|
||||
align, ceil_div,
|
||||
@@ -11,7 +11,6 @@ from deep_gemm.utils import (
|
||||
|
||||
|
||||
class KernelType(enum.Enum):
|
||||
# For SM100 GEMMs
|
||||
Kernel1D1D = 0
|
||||
Kernel1D2D = 1
|
||||
KernelNoSF = 2
|
||||
@@ -48,62 +47,87 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool:
|
||||
return kernel_type.is_1d1d()
|
||||
|
||||
|
||||
def get_kernel_types(use_bf16: bool = False) -> tuple:
|
||||
if use_bf16:
|
||||
def get_kernel_types(dtype: torch.dtype) -> tuple:
|
||||
if dtype == torch.bfloat16:
|
||||
return (KernelType.KernelNoSF, )
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)
|
||||
|
||||
# TODO: SM100 1D2D kernels are going to be deprecated
|
||||
# But if you want to test it, please use:
|
||||
# `(KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)`
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, )
|
||||
|
||||
|
||||
def get_out_dtype() -> tuple:
|
||||
return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float)
|
||||
def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
|
||||
for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
|
||||
for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
|
||||
if major_a.is_mn_major() and not allow_a_mn_major:
|
||||
continue
|
||||
if major_b.is_mn_major() and not allow_b_mn_major:
|
||||
continue
|
||||
yield major_a, major_b
|
||||
|
||||
|
||||
def get_major_ab(freeze_a: bool) -> tuple:
|
||||
# TODO: test other major-ness for SM90 BF16 GEMMs
|
||||
if get_arch_major() == 9:
|
||||
return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), )
|
||||
if freeze_a:
|
||||
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor)
|
||||
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \
|
||||
(MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
def enumerate_normal(dtype: torch.dtype) -> Generator:
|
||||
assert dtype in (torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
fp32_output_nk = [(256, 7168), (129280, 7168)]
|
||||
bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
|
||||
m_fwd_list, m_bwd_list = [128, 4096], [4096, ]
|
||||
nk_list = bf16_output_nk
|
||||
|
||||
# Only BF16 GEMM needs FP32 outputs
|
||||
if dtype == torch.bfloat16:
|
||||
nk_list += fp32_output_nk
|
||||
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
# Forward
|
||||
for m in m_fwd_list:
|
||||
for n, k in nk_list:
|
||||
out_dtype = torch.float if (n, k) in fp32_output_nk else torch.bfloat16
|
||||
yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
|
||||
|
||||
# TODO: support BF16 SM90 MN-major kernels
|
||||
if dtype == torch.bfloat16 and get_arch_major() == 9:
|
||||
continue
|
||||
|
||||
# Backward
|
||||
for m in m_bwd_list:
|
||||
for n, k in nk_list:
|
||||
override_major = MajorTypeAB.MNMajor
|
||||
override_kernel_type = kernel_type
|
||||
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
|
||||
override_major = MajorTypeAB.KMajor
|
||||
override_kernel_type = KernelType.Kernel1D1D
|
||||
yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
|
||||
|
||||
|
||||
def enumerate_normal(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
|
||||
for major_a, major_b in get_major_ab(False):
|
||||
for out_dtype in get_out_dtype():
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
|
||||
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
|
||||
|
||||
|
||||
def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
|
||||
for major_a, major_b in get_major_ab(True):
|
||||
for major_a, major_b in get_major_ab(False, get_arch_major() > 9):
|
||||
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
|
||||
|
||||
|
||||
def enumerate_m_grouped_masked() -> Generator:
|
||||
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
|
||||
max_m = 4096
|
||||
for kernel_type in get_kernel_types():
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
||||
for n, k in ((4096, 7168), (7168, 2048), ):
|
||||
yield kernel_type, num_groups, max_m, m, n, k
|
||||
|
||||
|
||||
def enumerate_k_grouped_contiguous():
|
||||
# TODO: support SM90 kernels
|
||||
if get_arch_major() == 9:
|
||||
return []
|
||||
|
||||
# Only K-major is supported for SM90
|
||||
major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 \
|
||||
else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
# Must with FP32 accumulation and 1D1D kernels
|
||||
for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64
|
||||
( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32
|
||||
(16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16
|
||||
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
|
||||
yield num_groups, m, n, ks, expected_k_per_group
|
||||
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group
|
||||
|
||||
|
||||
def enumerate_sf_layout():
|
||||
@@ -134,6 +158,7 @@ def enumerate_transpose():
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
kernel_type: KernelType,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
@@ -147,7 +172,9 @@ def generate_normal(m: int, n: int, k: int,
|
||||
b = b if major_b.is_k_major() else b.T.contiguous().T
|
||||
return a, b, c, d, ref_d
|
||||
|
||||
a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \
|
||||
else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
|
||||
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
|
||||
return a_fp8, b_fp8, c, d, ref_d
|
||||
@@ -214,7 +241,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
|
||||
return a_fp8, b_fp8, masked_m, d, ref_d
|
||||
|
||||
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool):
|
||||
assert get_mk_alignment_for_contiguous_layout() % 128 == 0
|
||||
k = sum(ks)
|
||||
|
||||
@@ -232,4 +259,20 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int]
|
||||
|
||||
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
|
||||
# Transpose for K Major A/B
|
||||
if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):
|
||||
a, sfa = a_fp8
|
||||
b, sfb = b_fp8
|
||||
new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device)
|
||||
new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device)
|
||||
prefix = 0
|
||||
for K in ks:
|
||||
new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten()
|
||||
new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten()
|
||||
prefix += K
|
||||
a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T)
|
||||
else:
|
||||
assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
|
||||
return k, a_fp8, b_fp8, c, d, ref_d
|
||||
|
||||
Reference in New Issue
Block a user