SM100 blockwise FP8 swap AB (#18564)

This commit is contained in:
Lain
2025-06-04 07:48:45 -07:00
committed by GitHub
parent 02658c2dfe
commit 5f2cd251d2
3 changed files with 139 additions and 83 deletions

View File

@@ -136,24 +136,10 @@ def apply_w8a8_block_fp8_linear(
use_cutlass, use_aiter_and_is_supported)
if use_cutlass:
rows, cols = input_2d.shape
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
# optimal tensor core usage. Can be removed when targeting platforms
# without this constraint.
should_pad = current_platform.has_device_capability(
100) and rows % 4 != 0
if should_pad:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, 4 - (rows % 4)),
value=0).contiguous()
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if should_pad:
output = output[:rows, :]
else:
q_input, x_scale = per_token_group_quant_fp8(