[Kernel][Performance] Enable smaller Scaling Factor tiling for NVFP4 small-batch decoding (#30885)

Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Author: Roberto L. Castro
Date: 2026-01-14 00:22:53 +01:00
Committed by: GitHub
Parent: 2a60ac91d0
Commit: 8ef50d9a6b
9 changed files with 177 additions and 32 deletions
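
Why the smaller tiling helps: with the 128x4 swizzle, the scale-factor buffer pads the row count up to a multiple of 128, so a decode step with only a handful of tokens pays for 128 rows of scales; the 8x4 tiling pads to 8 rows instead. A minimal sketch of that padding arithmetic (`sf_buffer_elems` is a hypothetical helper introduced here, mirroring the tile math in the converters below):

```python
def sf_buffer_elems(m: int, k: int, block_size: int, row_tile: int) -> int:
    # Hypothetical helper: number of fp8 scale-factor entries after padding
    # m up to the row tile and k up to 4 scale columns per k-tile, mirroring
    # the m_tiles/k_tiles arithmetic in the converters in this diff.
    m_tiles = (m + row_tile - 1) // row_tile
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    return (m_tiles * row_tile) * (k_tiles * 4)

m, k, block_size = 4, 4096, 16                 # e.g. a 4-token decode step
print(sf_buffer_elems(m, k, block_size, 128))  # 128x4 layout -> 32768
print(sf_buffer_elems(m, k, block_size, 8))    # 8x4 layout   ->  2048 (16x smaller)
```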

@@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
     return out[0:m, 0:k]


+def convert_swizzled_8x4_layout_to_linear(
+    a_sf_swizzled: torch.Tensor, m, k, block_size
+):
+    m_tiles = (m + 8 - 1) // 8
+    f = block_size * 4
+    k_tiles = (k + f - 1) // f
+    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 8, 4))
+    tmp = torch.permute(tmp, (0, 1, 3, 2, 4))
+    out = tmp.reshape(m_tiles * 8, k_tiles * f // block_size)
+    return out[0:m, 0:k]
+
+
 def dequantize_nvfp4_to_dtype(
-    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
+    tensor_fp4,
+    tensor_sf,
+    global_scale,
+    dtype,
+    device,
+    block_size=16,
+    is_sf_128x4_layout=True,
 ):
     """Dequantize the fp4 tensor back to high precision."""
     # Two fp4 values are packed into one uint8.
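
The `(0, 1, 3, 2, 4)` permutation in the new converter just swaps the 8-row axis with the k-tile axis and is its own inverse, so a round trip through a linear-to-swizzled packer should reproduce the input exactly. A sketch, assuming `convert_swizzled_8x4_layout_to_linear` from the hunk above is in scope; `convert_linear_to_swizzled_8x4` is a name introduced here for illustration, not part of this change:

```python
import torch

def convert_linear_to_swizzled_8x4(a_sf_linear, m, k, block_size):
    # Hypothetical inverse of convert_swizzled_8x4_layout_to_linear: pad the
    # row-major (m, k // block_size) scale matrix to whole 8x4 tiles, then
    # interleave it with the same (self-inverse) permutation.
    m_tiles = (m + 8 - 1) // 8
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    padded = torch.zeros(m_tiles * 8, k_tiles * 4, dtype=a_sf_linear.dtype)
    padded[:m, : k // block_size] = a_sf_linear
    tmp = padded.reshape(1, m_tiles, 8, k_tiles, 4)
    tmp = torch.permute(tmp, (0, 1, 3, 2, 4))  # swap row and k-tile axes back
    return tmp.reshape(m_tiles * 8, k_tiles * 4)

m, k, block_size = 5, 128, 16  # m=5 exercises padding to the 8-row tile
lin = torch.arange(m * (k // block_size), dtype=torch.float32)
lin = lin.reshape(m, k // block_size)
swizzled = convert_linear_to_swizzled_8x4(lin, m, k, block_size)
assert torch.equal(
    convert_swizzled_8x4_layout_to_linear(swizzled, m, k, block_size), lin
)
```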
@@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype(
     tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
     tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
     tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
-    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+    if is_sf_128x4_layout:
+        tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+    else:
+        tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)
     tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
     # scale the tensor
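
For callers, the only visible change is the new keyword: pass `is_sf_128x4_layout=False` when the scale factors were produced in the 8x4 layout. A usage sketch with placeholder inputs; the import path and the all-zero fp4 payload are assumptions, and the exact output values depend on `break_fp4_bytes`, which this diff does not show:

```python
import torch
# Import path mirrors vLLM's test utilities and is an assumption.
from tests.kernels.quant_utils import dequantize_nvfp4_to_dtype

m, k, block_size = 4, 64, 16
tensor_fp4 = torch.zeros(m, k // 2, dtype=torch.uint8)  # two fp4 values per byte
# 8x4-swizzled scales: m padded to 8 rows, k/block_size = 4 scale columns.
tensor_sf = torch.ones(8 * 4, dtype=torch.uint8)        # viewed as e4m3 inside
global_scale = torch.tensor(1.0)

out = dequantize_nvfp4_to_dtype(
    tensor_fp4,
    tensor_sf,
    global_scale,
    dtype=torch.bfloat16,
    device="cpu",               # device handling inside the helper is an assumption
    block_size=block_size,
    is_sf_128x4_layout=False,   # scales use the new 8x4 tiling
)
assert out.shape == (m, k)      # expected high-precision shape
```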