[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel (#23280)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Michael Goin
2025-09-11 18:43:14 -04:00
committed by GitHub
parent d4fd2768ef
commit c3aea10dc8
13 changed files with 221 additions and 1260 deletions

View File

@@ -40,11 +40,14 @@ def cutlass_scaled_mm(
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return ops.cutlass_scaled_mm(A,
B.T,
out_dtype=output_dtype,
scale_a=As,
scale_b=Bs.T)
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b=Bs if block_size is not None
and current_platform.is_device_capability(90) else Bs.T)
def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -152,35 +155,32 @@ def apply_w8a8_block_fp8_linear(
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
if current_platform.is_cuda():
if current_platform.has_device_capability(100):
use_cutlass = cutlass_block_fp8_supported and (
cdiv(weight.shape[0], 128) == weight_scale.shape[0]
and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
else:
# TODO: update this after switching to public sm90 block scale gemm
# as it also supports weight.shape % 128 != 0
use_cutlass = cutlass_block_fp8_supported and (
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else:
use_cutlass = False
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported)
if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
cutlass_block_fp8_supported, use_aiter_and_is_supported)
if cutlass_block_fp8_supported:
num_pad = 0
if current_platform.is_device_capability(90):
# pad first dimension to be divisible by 4 due to
# cutlass blockwise gemm limitation for hopper
num_pad = 4 - (input_2d.shape[0] % 4)
if num_pad > 0:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, num_pad),
"constant", 0)
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if num_pad > 0:
output = output[:-num_pad]
else:
if use_aiter_and_is_supported:
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
input_2d, block_size[1], column_major_scales=False)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)