Unify ceil_divs
This commit is contained in:
@@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
||||
|
||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||
x_fp8 = per_token_cast_to_fp8(x)
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
|
||||
@@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
|
||||
|
||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
@@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
||||
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
|
||||
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
|
||||
|
||||
total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
|
||||
total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
|
||||
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
|
||||
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
|
||||
|
||||
@@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
||||
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
|
||||
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
|
||||
|
||||
num_scales = (k + 127) // 128
|
||||
num_scales = ceil_div(k, 128)
|
||||
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
|
||||
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user