Make various updates and fixes: (#164)

- Add BF16 support for SM90 and SM100
- Refactor Python APIs
- Other fixes and code refactoring
This commit is contained in:
Ray Wang
2025-08-15 18:32:35 +08:00
committed by GitHub
parent 3254b758e2
commit f85ec649d7
34 changed files with 2293 additions and 495 deletions

View File

@@ -1,16 +1,18 @@
import time
import torch
import random
from deep_gemm.testing import bench_kineto, count_bytes
from deep_gemm.testing import bench_kineto, count_bytes, calc_diff
from deep_gemm.utils import (
align, ceil_div,
per_token_cast_to_fp8, per_channel_cast_to_fp8,
get_tma_aligned_size,
get_mn_major_tma_aligned_tensor,
get_mn_major_tma_aligned_packed_ue8m0_tensor,
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
)
from generators import (
enumerate_transpose,
enumerate_sf_layout,
enumerate_k_grouped_sf_layout
)
@@ -43,29 +45,39 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) ->
def test_sf_layout_kernels() -> None:
# Checks the SF layout helpers on every case yielded by enumerate_sf_layout(),
# then benchmarks the helper that was checked and prints one perf line per case.
# NOTE(review): this listing is a diff hunk shown without +/- markers or
# indentation. Where two near-identical statements are adjacent (e.g. the two
# `for` headers below), the first looks like the removed line and the second
# like its replacement -- confirm against the actual file.
print('Testing SF layout kernels:')
for mn, k, with_transpose, num_groups in enumerate_sf_layout():
for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout():
x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda')
x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=True)
x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0)
# For grouped cases, split the leading dimension into (num_groups, mn).
fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1)
# When with_transpose is false, transpose -> contiguous -> transpose keeps the
# shape but swaps the memory order of the last two dims.
fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2)
# Correctness
packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
assert packed_sf.shape == ref_packed_sf.shape
assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
# Test launch overhead
launch_start_t = time.time_ns()
get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
launch_end_t = time.time_ns()
# `impl` and `name` chosen in either branch are reused by the benchmark below.
if use_ue8m0:
# Packed path: values, shape and strides must all match the torch reference.
impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0'
packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
assert packed_sf.shape == ref_packed_sf.shape
assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
else:
# Non-packed path: the values must be unchanged; only the strides are checked
# against the expected layout.
impl, name = get_mn_major_tma_aligned_tensor, 'transpose'
transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf)
# Expected aligned mn extent and the number of SF columns, ceil(k / 128).
tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128)
if num_groups > 1:
# Each group is expected to occupy tma_aligned_mn * sf_k elements.
assert transposed_sf.size(0) == num_groups
assert transposed_sf.stride(0) == tma_aligned_mn * sf_k
assert transposed_sf.shape[-2:] == (mn, sf_k)
# Stride 1 along mn, and the aligned mn extent between consecutive SF columns.
assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn)
assert torch.equal(fp32_sf, transposed_sf)
# Performance
t = bench_kineto(lambda: get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf), 'pack_fp32_into_ue8m0')
print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}): '
f'launch {(launch_end_t - launch_start_t) / 1e3:3.0f} us | {t * 1e6:4.0f} us | '
f'{count_bytes(fp32_sf, packed_sf) / 1e9 / t:4.0f} GB/s')
try:
t = bench_kineto(lambda: impl(fp32_sf), name)
except AssertionError as e:
# Some cases may fall back to the PyTorch impl
# NOTE(review): presumably bench_kineto asserts when no kernel matching `name`
# is found -- confirm. t = 0 means "not measured"; the print below then
# reports 0 us and 0 GB/s.
t = 0
# impl(fp32_sf) is called once more here only to get the output for count_bytes.
print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): '
f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s')
print()