[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel (#23280)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -40,11 +40,14 @@ def cutlass_scaled_mm(
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
return ops.cutlass_scaled_mm(A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs.T)
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
# SM90 block FP8 requires a row-major scale_b; that layout is prepared ahead
# of time, so Bs is passed as-is there and transposed on other devices
|
||||
scale_b=Bs if block_size is not None
|
||||
and current_platform.is_device_capability(90) else Bs.T)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_impl(
|
||||
@@ -152,35 +155,32 @@ def apply_w8a8_block_fp8_linear(
|
||||
output += bias
|
||||
return output.to(dtype=output_dtype).view(*output_shape)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.has_device_capability(100):
|
||||
|
||||
use_cutlass = cutlass_block_fp8_supported and (
|
||||
cdiv(weight.shape[0], 128) == weight_scale.shape[0]
|
||||
and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
|
||||
else:
|
||||
# TODO: update this after switching to public sm90 block scale gemm
|
||||
# as it also supports weight.shape % 128 != 0
|
||||
use_cutlass = cutlass_block_fp8_supported and (
|
||||
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
||||
else:
|
||||
use_cutlass = False
|
||||
|
||||
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
|
||||
use_cutlass, use_aiter_and_is_supported)
|
||||
if use_cutlass:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=use_cutlass)
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported)
|
||||
if cutlass_block_fp8_supported:
|
||||
num_pad = 0
|
||||
if current_platform.is_device_capability(90):
|
||||
# pad first dimension to be divisible by 4 due to
|
||||
# cutlass blockwise gemm limitation for hopper
|
||||
num_pad = 4 - (input_2d.shape[0] % 4)
|
||||
if num_pad > 0:
|
||||
input_2d = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, num_pad),
|
||||
"constant", 0)
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
|
||||
if num_pad > 0:
|
||||
output = output[:-num_pad]
|
||||
else:
|
||||
if use_aiter_and_is_supported:
|
||||
q_input, x_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=use_cutlass)
|
||||
input_2d, block_size[1], column_major_scales=False)
|
||||
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
|
||||
Reference in New Issue
Block a user