[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
committed by
GitHub
parent
1ebdff412a
commit
fcb9df99bd
@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm(
|
||||
# from checkpoints are in linear scales.
|
||||
# So instead of needing to swizzle for cutlass as in modelopt.py,
|
||||
# we need to unswizzle for trtllm here.
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
|
||||
a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend
|
||||
)
|
||||
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)
|
||||
|
||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(
|
||||
b_dtype, b_global_scale, is_sf_swizzled_layout=True
|
||||
)
|
||||
|
||||
# get_ref_results unswizzles the scales internally.
|
||||
expected_out = get_ref_results(
|
||||
|
||||
@@ -27,6 +27,12 @@ PAD_SHAPES = [
|
||||
(150, 128),
|
||||
(150, 48),
|
||||
(90, 80),
|
||||
(128, 512),
|
||||
(128, 1024),
|
||||
(128, 2048),
|
||||
(64, 7168),
|
||||
(64, 7152),
|
||||
(32, 14336),
|
||||
]
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
out_ans = cast_from_fp4(out, m, n)
|
||||
torch.testing.assert_close(out_ans, out_ref)
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4_padded_no_sf_swizzled(pad_shape: tuple[int, int]) -> None:
|
||||
dtype = torch.float16
|
||||
set_random_seed(42)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
m, n = pad_shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
|
||||
|
||||
out, out_scale = ops.scaled_fp4_quant(x, global_scale, is_sf_swizzled_layout=False)
|
||||
scale_ans = out_scale.to(torch.float32)
|
||||
out_ans = cast_from_fp4(out, m, n)
|
||||
torch.testing.assert_close(out_ans, out_ref)
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
Reference in New Issue
Block a user