[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
Roberto L. Castro
2026-01-25 02:45:27 +01:00
committed by GitHub
parent 1ebdff412a
commit fcb9df99bd
18 changed files with 508 additions and 151 deletions

View File

@@ -27,6 +27,12 @@ PAD_SHAPES = [
(150, 128),
(150, 48),
(90, 80),
(128, 512),
(128, 1024),
(128, 2048),
(64, 7168),
(64, 7152),
(32, 14336),
]
# Seeds used to parametrize the quantization tests (kept small for CI speed).
SEEDS = [42]
# Devices the tests run on; the fp4 quantization ops are exercised on CUDA only.
CUDA_DEVICES = ["cuda:0"]
@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
out_ans = cast_from_fp4(out, m, n)
torch.testing.assert_close(out_ans, out_ref)
torch.testing.assert_close(scale_ans, scale_ref)
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded_no_sf_swizzled(pad_shape: tuple[int, int]) -> None:
    """Quantize a padded fp16 tensor to FP4 with the non-swizzled scale
    layout and compare both packed values and scales against the reference
    quantizer.
    """
    set_random_seed(42)
    torch.set_default_device("cuda:0")

    rows, cols = pad_shape
    x = torch.randn((rows, cols), dtype=torch.float16)

    # Per-tensor global scale derived from the absolute maximum of the input.
    amax = x.abs().max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

    expected_out, expected_scale = ref_nvfp4_quant(x, global_scale)

    # Exercise the kernel path that emits scale factors in the plain
    # (non-swizzled) layout.
    packed, raw_scale = ops.scaled_fp4_quant(
        x, global_scale, is_sf_swizzled_layout=False
    )

    torch.testing.assert_close(cast_from_fp4(packed, rows, cols), expected_out)
    torch.testing.assert_close(raw_scale.to(torch.float32), expected_scale)