Sm100 blockwise fp8 swap ab (#18564)
This commit is contained in:
@@ -136,24 +136,10 @@ def apply_w8a8_block_fp8_linear(
|
||||
use_cutlass, use_aiter_and_is_supported)
|
||||
|
||||
if use_cutlass:
|
||||
rows, cols = input_2d.shape
|
||||
# Blackwell GPUs (SM100) require row dimensions to be a multiple of 4 for
|
||||
# optimal tensor core usage. Can be removed when targeting platforms
|
||||
# without this constraint.
|
||||
should_pad = current_platform.has_device_capability(
|
||||
100) and rows % 4 != 0
|
||||
if should_pad:
|
||||
input_2d = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, 4 - (rows % 4)),
|
||||
value=0).contiguous()
|
||||
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=use_cutlass)
|
||||
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
if should_pad:
|
||||
output = output[:rows, :]
|
||||
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
|
||||
Reference in New Issue
Block a user