[NVIDIA] Support nvfp4 quantization (#12784)

Kaixi Hou
2025-02-12 19:51:51 -08:00
committed by GitHub
parent 9f9704dca6
commit 4fc5c23bb6
12 changed files with 688 additions and 13 deletions

@@ -765,6 +765,63 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.permute_cols(a, perm)


# fp4
def scaled_fp4_quant(
        input: torch.Tensor,
        input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`.
    For every 16 consecutive elements, a single dynamically computed scaling
    factor is shared. This scaling factor is quantized using the
    `input_global_scale` and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (with
            every two values packed into a uint8) and the float8_e4m3
            scaling factors in the swizzled layout.
    """
    assert input.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {input.ndim}.')
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, (
        f'last dim has to be multiple of 16, but got {n}.')
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')

    # Two fp4 values will be packed into one uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
    # The swizzled scales are stored using the rounded-up sizes: the Tensor
    # Core requires a minimum 128x4 tile for the scales, so we first pad the
    # scale dimensions to multiples of 128 and 4. Then, every 4 scales (in
    # float8_e4m3fn) are packed into one int32. More:
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
    round_up = lambda x, y: (x + y - 1) // y * y
    rounded_m = round_up(m, 128)
    scale_n = n // block_size
    rounded_n = round_up(scale_n, 4)
    output_scale = torch.empty((rounded_m, rounded_n // 4),
                               device=device,
                               dtype=torch.int32)
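    # Illustrative shape check (added for clarity, not in the original diff):
    # an input of shape (32, 64) gives output (32, 32) uint8; scale_n = 4,
    # rounded_m = 128, rounded_n = 4, so output_scale is allocated as
    # (128, 1) int32, which is viewed as (128, 4) float8_e4m3fn below.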
    torch.ops._C.scaled_fp4_quant(output, input, output_scale,
                                  input_global_scale)
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
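
For context, a minimal usage sketch of the new function (not part of the
diff). The import path assumes the function lands in vLLM's `_custom_ops`
module, the global-scale formula is an illustrative choice rather than
something this function mandates, and the underlying kernel requires a CUDA
build and a GPU with FP4 support:

    import torch
    from vllm import _custom_ops as ops  # assumed import path

    x = torch.randn(32, 64, dtype=torch.bfloat16, device='cuda')
    # Assumed heuristic: map the tensor's max magnitude into the combined
    # fp8-scale x fp4 range (448 is the float8_e4m3 max, 6 is the e2m1 max).
    global_scale = (448 * 6) / x.abs().max().float()
    packed, scales = ops.scaled_fp4_quant(x, global_scale)
    assert packed.shape == (32, 32)   # two FP4 values per uint8
    assert scales.shape == (128, 4)   # scales padded to a 128x4 tile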

# fp8
def scaled_fp8_quant(
        input: torch.Tensor,

@@ -321,6 +321,9 @@ class scalar_types:
    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
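
The new float4_e2m1fn type has 1 sign bit, 2 exponent bits (bias 1), 1
mantissa bit, and no NaN representation, so its 16 codes cover the values
+/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}. A small decoder sketch (illustrative, not
part of this commit):

    # Decode a 4-bit e2m1 code: bit 3 = sign, bits 2-1 = exponent, bit 0 = mantissa.
    E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

    def decode_e2m1(code: int) -> float:
        sign = -1.0 if code & 0b1000 else 1.0
        return sign * E2M1_MAGNITUDES[code & 0b0111]

    assert [decode_e2m1(c) for c in range(8)] == E2M1_MAGNITUDES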