From 350989eef34bc0e48a33a7fcae4656ebc53fe4c0 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:48:32 +0800 Subject: [PATCH] Unify `ceil_div`s --- deep_gemm/jit_kernels/gemm.py | 8 ++--- deep_gemm/jit_kernels/m_grouped_gemm.py | 10 +++--- deep_gemm/jit_kernels/wgrad_gemm.py | 42 ++++++++++++------------- tests/test_core.py | 8 ++--- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 343e84a..5f7a123 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Type and shape checks assert m == m_ and n == n_ and k == k_ assert n > 0 and k > 0 - assert lhs_scales.shape == (m, (k + 127) // 128) - assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) + assert lhs_scales.shape == (m, ceil_div(k, 128)) + assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 # LHS scales must be transposed for TMA loads, but not for RHS scales - # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels + # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert rhs_scales.is_contiguous() @@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], return # K must be aligned to 128 - aligned_k = (k + 127) // 128 * 128 + aligned_k = ceil_div(k, 128) * 128 # Auto-tuning with compilation num_sms = get_num_sms() diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 73fd2f1..c2f2d93 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -7,7 +7,7 @@ from .runtime import ( FP8GemmRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import get_col_major_tma_aligned_tensor, get_num_sms +from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Type and shape checks assert m == m_ == m__ and k == k_ and n == n_ - assert lhs_scales.shape == (m, (k + 127) // 128) - assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) + assert lhs_scales.shape == (m, ceil_div(k, 128)) + assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 @@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] assert num_groups == num_groups_ == num_groups__ == num_groups___ assert m == m_ and n == n_ and k == k_ assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) - assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) + assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) + assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index 8a38578..658f005 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -7,7 +7,7 @@ from .runtime import ( make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_desc) from .gemm import get_best_configs -from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size +from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -40,41 +40,39 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Type and shape checks assert m == m_ and n == n_ and k == k_ assert n > 0 and m > 0 - assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m) - assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n) + assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m) + assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.float assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 # LHS and RHS scales must be transposed for TMA load - # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels - if lhs_scales.shape == ((k + 127) // 128, m): - lhs_scales = lhs_scales.permute(1, 0) - assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m - else: - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert lhs_scales.stride(0) == 1 - - if rhs_scales.shape == ((k + 127) // 128, n): - rhs_scales = rhs_scales.permute(1, 0) - assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n - else: - rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales) - assert rhs_scales.stride(0) == 1 + # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels + def get_valid_scales(scales: torch.Tensor, mn: int): + if scales.shape == (ceil_div(k, 128), mn): + # For k-grouped GEMMs + scales = scales.permute(1, 0) + assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn + else: + scales = get_col_major_tma_aligned_tensor(scales) + return scales + + lhs_scales = get_valid_scales(lhs_scales, m) + rhs_scales = get_valid_scales(rhs_scales, n) # Do nothing if `k` is zero if k == 0: return # K must be aligned to 128 - aligned_k = (k + 127) // 128 * 128 + aligned_k = ceil_div(k, 128) * 128 # Auto-tuning with compilation num_sms = get_num_sms() num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) - num_last_stages = (k + 127) // 128 % num_stages + num_last_stages = ceil_div(k, 128) % num_stages block_k = 128 num_tma_threads = 128 num_math_threads_per_group = 128 @@ -151,10 +149,10 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], k = batch_sizes[i] lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k) rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k) - lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128] - rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128] + lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] + rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i]) lhs_offset += m * k rhs_offset += n * k - scales_offset += (k + 127) // 128 + scales_offset += ceil_div(k, 128) diff --git a/tests/test_core.py b/tests/test_core.py index 36c1c34..03038db 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: assert m % 4 == 0, f'TMA alignment error: {m}' x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) @@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \ assert m % 4 == 0, f'TMA alignment error: {m}' x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) for i in range(num_groups): x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) @@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) - total_scale_factors = sum((k + 127) // 128 for k in k_sizes) + total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes) x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) @@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) - num_scales = (k + 127) // 128 + num_scales = ceil_div(k, 128) x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)