Add CUTLASS support for Blackwell FP8 blockwise GEMM (#14383)

commit 376786fac1 (parent 4f605a6de5)
Author: Shu Wang <shuw@nvidia.com>
Date: 2025-05-08 17:09:55 -05:00
Committer: GitHub
Signed-off-by: Shu Wang <shuw@nvidia.com>

11 changed files with 332 additions and 64 deletions

@@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
                 or br not in (1, weight.shape[0])):
             shape_supported_by_cutlass = False
     if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+        rows, cols = input_2d.shape
+        # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
+        # optimal tensor core usage. Can be removed when targeting platforms
+        # without this constraint.
+        should_pad = current_platform.has_device_capability(
+            100) and rows % 4 != 0
+        if should_pad:
+            input_2d = torch.nn.functional.pad(input_2d,
+                                               (0, 0, 0, 4 - (rows % 4)),
+                                               value=0).contiguous()
         q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                      block_size[1],
                                                      column_major_scales=True)
@@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
                                        out_dtype=input.dtype,
                                        scale_a=x_scale,
                                        scale_b=weight_scale.T)
+        if should_pad:
+            output = output[:rows, :]
     else:
         q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                      block_size[1],
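
For context, the pattern these two hunks introduce is: zero-pad the activation's row dimension up to the next multiple of 4 before dispatching the CUTLASS GEMM on SM100, then slice the padded rows back off the output. Below is a minimal standalone sketch of that pattern; the helper name pad_rows_to_multiple_of_4 and the plain "@" matmul standing in for ops.cutlass_scaled_mm are illustrative assumptions, not vLLM API.

import torch


def pad_rows_to_multiple_of_4(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    # Hypothetical helper mirroring the diff above: SM100 kernels want
    # rows % 4 == 0, so zero-pad the bottom and remember the original
    # row count for trimming the output later.
    rows = x.shape[0]
    if rows % 4 == 0:
        return x, rows
    # For a 2D tensor, F.pad takes (left, right, top, bottom):
    # pad only the bottom with extra zero rows, columns untouched.
    padded = torch.nn.functional.pad(
        x, (0, 0, 0, 4 - (rows % 4)), value=0).contiguous()
    return padded, rows


a = torch.randn(6, 8)                        # 6 rows: not a multiple of 4
b = torch.randn(8, 16)
a_padded, orig_rows = pad_rows_to_multiple_of_4(a)   # now 8 rows
out = a_padded @ b                           # stand-in for ops.cutlass_scaled_mm
out = out[:orig_rows, :]                     # drop rows added by padding
assert out.shape == (6, 16)

Because the padded rows are all zeros, the extra output rows carry no information, and slicing them off recovers exactly the unpadded result in this plain-matmul sketch; the diff applies the same idea around the quantized GEMM.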