[Kernel][Performance] Enable smaller Scaling Factor tiling for NVFP4 small-batch decoding (#30885)

Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
Roberto L. Castro
2026-01-14 00:22:53 +01:00
committed by GitHub
parent 2a60ac91d0
commit 8ef50d9a6b
9 changed files with 177 additions and 32 deletions

View File

@@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
return out[0:m, 0:k]
def convert_swizzled_8x4_layout_to_linear(
a_sf_swizzled: torch.Tensor, m, k, block_size
):
m_tiles = (m + 8 - 1) // 8
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 8, 4))
tmp = torch.permute(tmp, (0, 1, 3, 2, 4))
out = tmp.reshape(m_tiles * 8, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16,
is_sf_128x4_layout=True,
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
@@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype(
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
if is_sf_128x4_layout:
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
else:
tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor

View File

@@ -11,7 +11,9 @@ from nvfp4_utils import (
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
)
from vllm.utils.torch_utils import set_random_seed
if not current_platform.has_device_capability(100):
@@ -22,8 +24,14 @@ if not current_platform.has_device_capability(100):
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES = [
(128, 128, 64),
(128, 128, 128),
(256, 128, 64),
(128, 256, 128),
(1, 128, 128),
]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96), (2, 128, 64), (3, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
@@ -42,12 +50,19 @@ def get_ref_results(
dtype,
block_size,
device,
is_sf_128x4_layout,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size,
is_sf_128x4_layout=is_sf_128x4_layout,
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
@@ -70,7 +85,7 @@ def test_flashinfer_nvfp4_gemm(
backend: str,
autotune: bool,
) -> None:
if backend == "trtllm" and dtype == torch.float16:
if "trtllm" in backend and dtype == torch.float16:
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
set_random_seed(seed)
@@ -87,11 +102,14 @@ def test_flashinfer_nvfp4_gemm(
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend)
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
@@ -107,14 +125,14 @@ def test_flashinfer_nvfp4_gemm(
dtype,
block_size,
device,
is_sf_128x4_layout,
)
import flashinfer
if backend == "trtllm":
if "trtllm" in backend:
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size
)