Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring
This commit is contained in:
@@ -59,6 +59,7 @@ def get_out_dtype() -> tuple:
|
||||
|
||||
|
||||
def get_major_ab(freeze_a: bool) -> tuple:
|
||||
# TODO: test other major-ness for SM90 BF16 GEMMs
|
||||
if get_arch_major() == 9:
|
||||
return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), )
|
||||
if freeze_a:
|
||||
@@ -70,15 +71,15 @@ def get_major_ab(freeze_a: bool) -> tuple:
|
||||
def enumerate_normal(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048), (129280, 7168)]:
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
|
||||
for major_a, major_b in get_major_ab(False):
|
||||
for out_dtype in get_out_dtype():
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or not kernel_type.is_1d1d() else (False, True):
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
|
||||
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
|
||||
|
||||
|
||||
def enumerate_m_grouped_contiguous() -> Generator:
|
||||
for kernel_type in get_kernel_types():
|
||||
def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
|
||||
for major_a, major_b in get_major_ab(True):
|
||||
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
|
||||
@@ -106,15 +107,12 @@ def enumerate_k_grouped_contiguous():
|
||||
|
||||
|
||||
def enumerate_sf_layout():
|
||||
for with_transpose in (True, False):
|
||||
for mn in (4096, 4097, 8192):
|
||||
for k in (128, 7168, 7296):
|
||||
for num_groups in (1, 2, 4) if with_transpose else (1, ):
|
||||
if num_groups > 1 and (mn * ceil_div(k, 128)) % 4 != 0:
|
||||
continue
|
||||
if not with_transpose and mn % 4 != 0:
|
||||
continue
|
||||
yield mn, k, with_transpose, num_groups
|
||||
for use_ue8m0 in (False, True):
|
||||
for with_transpose in (True, False):
|
||||
for mn in (4096, 4097, 8192):
|
||||
for k in (128, 7168, 7296):
|
||||
for num_groups in (1, 2, 4):
|
||||
yield mn, k, with_transpose, use_ue8m0, num_groups
|
||||
|
||||
|
||||
def enumerate_k_grouped_sf_layout():
|
||||
@@ -126,6 +124,13 @@ def enumerate_k_grouped_sf_layout():
|
||||
yield mn, ks, num_groups
|
||||
|
||||
|
||||
def enumerate_transpose():
|
||||
for mn in (64, 4096, 16384):
|
||||
for delta in (0, 101, 202, 303):
|
||||
for k in (128, 1024, 4096, 9984, 16384):
|
||||
yield mn + delta, k
|
||||
|
||||
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
@@ -149,8 +154,8 @@ def generate_normal(m: int, n: int, k: int,
|
||||
|
||||
|
||||
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool) -> \
|
||||
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
|
||||
aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms]
|
||||
m = sum(aligned_ms)
|
||||
@@ -171,6 +176,10 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
|
||||
start = aligned_end
|
||||
ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d)
|
||||
|
||||
if use_bf16:
|
||||
b = b if major_b.is_k_major() else b.mT.contiguous().mT
|
||||
return m, a, b, m_indices, d, ref_d
|
||||
|
||||
assert major_a.is_k_major()
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn),
|
||||
@@ -181,24 +190,27 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
|
||||
return m, a_fp8, b_fp8, m_indices, d, ref_d
|
||||
|
||||
|
||||
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, use_ue8m0: bool) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||
d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_d = torch.einsum('gmk,gnk->gmn', a, b)
|
||||
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
|
||||
assert masked_m.amax().item() <= max_m
|
||||
|
||||
if use_bf16:
|
||||
return a, b, masked_m, d, ref_d
|
||||
|
||||
a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float))
|
||||
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0)
|
||||
b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)
|
||||
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
|
||||
assert masked_m.amax().item() <= max_m
|
||||
|
||||
return a_fp8, b_fp8, masked_m, d, ref_d
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user