[NVIDIA] Support nvfp4 quantization (#12784)
@@ -765,6 +765,63 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.permute_cols(a, perm)


# fp4
def scaled_fp4_quant(
        input: torch.Tensor,
        input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4, with every
        two values packed into a uint8, and the float8_e4m3 scaling factors
        in the swizzled layout.
    """
    assert input.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {input.ndim}.')
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, (
        f'last dim has to be multiple of 16, but got {n}.')
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')

    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    # We use the rounded values to store the swizzled values. Due to the
    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
    # So, we first pad the scales to multiples of 128 and 4. Then, the scales
    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
    round_up = lambda x, y: (x + y - 1) // y * y
    rounded_m = round_up(m, 128)
    scale_n = n // block_size
    rounded_n = round_up(scale_n, 4)
    output_scale = torch.empty((rounded_m, rounded_n // 4),
                               device=device,
                               dtype=torch.int32)

    torch.ops._C.scaled_fp4_quant(output, input, output_scale,
                                  input_global_scale)
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
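For reference, a minimal usage sketch (not part of this diff), assuming a CUDA build of vLLM with the `_C` extension available. The global-scale formula `(448 * 6) / amax` (fp8 e4m3 max times fp4 e2m1 max, divided by the tensor's absolute maximum) is a common nvfp4 recipe and an assumption here, not something this commit prescribes.

    import torch
    from vllm import _custom_ops as ops

    # Hypothetical input: fp16 activations on a CUDA device.
    x = torch.randn(128, 4096, dtype=torch.float16, device="cuda")

    # Assumed per-tensor global scale so the fp8 block scales (max 448)
    # times the fp4 values (max 6) cover the input's dynamic range.
    global_scale = (448.0 * 6.0) / x.abs().max().to(torch.float32)

    x_fp4, x_scales = ops.scaled_fp4_quant(x, global_scale)
    print(x_fp4.shape)     # torch.Size([128, 2048]) - two fp4 values per uint8
    print(x_scales.shape)  # torch.Size([128, 256])  - one fp8 scale per 16 elements
    print(x_scales.dtype)  # torch.float8_e4m3fn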
@@ -321,6 +321,9 @@ class scalar_types:
    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
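Side note on the new scalar type: `ScalarType.float_(2, 1, True, NanRepr.NONE)` describes a signed format with 2 exponent bits, 1 mantissa bit, and no NaN encoding. The standalone sketch below (not from the source) enumerates the representable e2m1 values from that definition as a sanity check; the 16-element block scale in `scaled_fp4_quant` is what stretches this tiny range over real data.

    # Enumerate fp4 e2m1 values: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
    def e2m1_values():
        vals = set()
        for exp in range(4):
            for man in range(2):
                if exp == 0:
                    mag = man / 2.0                       # subnormal: no implicit leading 1
                else:
                    mag = (1 + man / 2.0) * 2.0 ** (exp - 1)
                vals.update((mag, -mag))
        return sorted(vals)

    print(e2m1_values())
    # [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]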