[Kernel][Performance] Enable smaller Scaling Factor tiling for NVFP4 small-batch decoding (#30885)
Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
committed by
GitHub
parent
2a60ac91d0
commit
8ef50d9a6b
@@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def convert_swizzled_8x4_layout_to_linear(
    a_sf_swizzled: torch.Tensor, m: int, k: int, block_size: int
) -> torch.Tensor:
    """Convert scale factors from the swizzled 8x4 tile layout to row-major.

    The input is interpreted as ``(m_tiles, k_tiles)`` tiles, each holding
    8 rows by 4 scale-factor columns; one scale-factor column covers
    ``block_size`` elements along K, so a tile spans ``4 * block_size``
    elements of K.

    Args:
        a_sf_swizzled: Tensor holding exactly
            ``ceil(m / 8) * ceil(k / (4 * block_size)) * 8 * 4`` scale
            factors in tile order.
        m: Number of valid rows.
        k: Number of valid elements along K (elements, not scale blocks).
        block_size: Number of K elements that share one scale factor.

    Returns:
        A 2-D tensor of shape ``(m, ceil(k / block_size))`` with the tile
        padding removed in both dimensions.

    Raises:
        RuntimeError: Raised by ``torch.reshape`` if the element count of
            ``a_sf_swizzled`` does not match the padded tile grid.
    """
    rows_per_tile = 8
    sf_cols_per_tile = 4

    m_tiles = (m + rows_per_tile - 1) // rows_per_tile
    k_per_tile = block_size * sf_cols_per_tile
    k_tiles = (k + k_per_tile - 1) // k_per_tile

    tiled = torch.reshape(
        a_sf_swizzled, (1, m_tiles, k_tiles, rows_per_tile, sf_cols_per_tile)
    )
    # Bring the in-tile row axis in front of the k-tile axis so that each
    # output row gathers its columns from every k-tile in order.
    row_major = torch.permute(tiled, (0, 1, 3, 2, 4))
    out = row_major.reshape(m_tiles * rows_per_tile, k_tiles * sf_cols_per_tile)

    # Columns count scale blocks, not K elements: slicing with `k` would
    # never drop the K padding, so trim to the real number of blocks.
    sf_cols = (k + block_size - 1) // block_size
    return out[0:m, 0:sf_cols]
|
||||
|
||||
|
||||
def dequantize_nvfp4_to_dtype(
|
||||
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
|
||||
tensor_fp4,
|
||||
tensor_sf,
|
||||
global_scale,
|
||||
dtype,
|
||||
device,
|
||||
block_size=16,
|
||||
is_sf_128x4_layout=True,
|
||||
):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
@@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype(
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
if is_sf_128x4_layout:
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
else:
|
||||
tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)
|
||||
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
|
||||
Reference in New Issue
Block a user