[Performance] Fused blockwise quant RMS norm (#27883)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
|
||||
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
|
||||
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
|
||||
|
||||
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
|
||||
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
|
||||
|
||||
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
|
||||
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
|
||||
Reference in New Issue
Block a user