[Performance] Fused blockwise quant RMS norm (#27883)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
ElizaWszola
2025-12-07 17:38:04 +01:00
committed by GitHub
parent 0044c4038c
commit af0444bf40
14 changed files with 949 additions and 157 deletions

View File

@@ -733,7 +733,7 @@ def per_token_group_quant_fp8(
assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty(x.shape, device=x.device, dtype=dtype)
# Allocate the scale tensor in either row- or column-major format.
if column_major_scales:

View File

@@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
# NVFP4-style quantization key: fp8 scales over 1x16 element groups,
# combined with a static per-tensor second-level scale (scale2).
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
# Symmetric fp8 quantization with fp32 scales per 1x128 group (blockwise
# group quant).  NOTE(review): the second positional ScaleDesc arg (False)
# presumably distinguishes static vs. dynamic scales, matching the
# "Dynamic" naming — confirm against the ScaleDesc definition.
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
# Same scheme as above, but with 1x64 groups.
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):