Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100
- Refactor Python APIs
- Other fixes and code refactoring
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
from deep_gemm.testing import bench_kineto, count_bytes
|
||||
from deep_gemm.testing import bench_kineto, count_bytes, calc_diff
|
||||
from deep_gemm.utils import (
|
||||
align, ceil_div,
|
||||
per_token_cast_to_fp8, per_channel_cast_to_fp8,
|
||||
get_tma_aligned_size,
|
||||
get_mn_major_tma_aligned_tensor,
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor,
|
||||
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
|
||||
)
|
||||
|
||||
from generators import (
|
||||
enumerate_transpose,
|
||||
enumerate_sf_layout,
|
||||
enumerate_k_grouped_sf_layout
|
||||
)
|
||||
@@ -43,29 +45,39 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) ->
|
||||
|
||||
def test_sf_layout_kernels() -> None:
|
||||
print('Testing SF layout kernels:')
|
||||
for mn, k, with_transpose, num_groups in enumerate_sf_layout():
|
||||
for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout():
|
||||
x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda')
|
||||
x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=True)
|
||||
x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0)
|
||||
fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1)
|
||||
fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2)
|
||||
|
||||
# Correctness
|
||||
packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
|
||||
ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
|
||||
assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
|
||||
assert packed_sf.shape == ref_packed_sf.shape
|
||||
assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
|
||||
|
||||
# Test launch overhead
|
||||
launch_start_t = time.time_ns()
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
|
||||
launch_end_t = time.time_ns()
|
||||
if use_ue8m0:
|
||||
impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0'
|
||||
packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
|
||||
ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
|
||||
assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
|
||||
assert packed_sf.shape == ref_packed_sf.shape
|
||||
assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
|
||||
else:
|
||||
impl, name = get_mn_major_tma_aligned_tensor, 'transpose'
|
||||
transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf)
|
||||
tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128)
|
||||
if num_groups > 1:
|
||||
assert transposed_sf.size(0) == num_groups
|
||||
assert transposed_sf.stride(0) == tma_aligned_mn * sf_k
|
||||
assert transposed_sf.shape[-2:] == (mn, sf_k)
|
||||
assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn)
|
||||
assert torch.equal(fp32_sf, transposed_sf)
|
||||
|
||||
# Performance
|
||||
t = bench_kineto(lambda: get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf), 'pack_fp32_into_ue8m0')
|
||||
print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}): '
|
||||
f'launch {(launch_end_t - launch_start_t) / 1e3:3.0f} us | {t * 1e6:4.0f} us | '
|
||||
f'{count_bytes(fp32_sf, packed_sf) / 1e9 / t:4.0f} GB/s')
|
||||
try:
|
||||
t = bench_kineto(lambda: impl(fp32_sf), name)
|
||||
except AssertionError as e:
|
||||
# Some cases may fallback to PyTorch impl
|
||||
t = 0
|
||||
print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): '
|
||||
f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user