Add CUTLASS support for Blackwell FP8 blockwise GEMM (#14383)
Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
@@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
|
||||
or br not in (1, weight.shape[0])):
|
||||
shape_supported_by_cutlass = False
|
||||
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
|
||||
rows, cols = input_2d.shape
|
||||
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
|
||||
# optimal tensor core usage. Can be removed when targeting platforms
|
||||
# without this constraint.
|
||||
should_pad = current_platform.has_device_capability(
|
||||
100) and rows % 4 != 0
|
||||
if should_pad:
|
||||
input_2d = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, 4 - (rows % 4)),
|
||||
value=0).contiguous()
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
@@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale.T)
|
||||
if should_pad:
|
||||
output = output[:rows, :]
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
|
||||
Reference in New Issue
Block a user