[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
committed by
GitHub
parent
1ebdff412a
commit
fcb9df99bd
@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm(
|
||||
# from checkpoints are in linear scales.
|
||||
# So instead of needing to swizzle for cutlass as in modelopt.py,
|
||||
# we need to unswizzle for trtllm here.
|
||||
- a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend)
+ a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
+     a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend
+ )
|
||||
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)
|
||||
|
||||
- b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
+ b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(
+     b_dtype, b_global_scale, is_sf_swizzled_layout=True
+ )
|
||||
|
||||
# get_ref_results unswizzles the scales internally.
|
||||
expected_out = get_ref_results(
|
||||
|
||||
Reference in New Issue
Block a user