[Performance] Fused blockwise quant RMS norm (#27883)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
ElizaWszola
2025-12-07 17:38:04 +01:00
committed by GitHub
parent 0044c4038c
commit af0444bf40
14 changed files with 949 additions and 157 deletions

View File

@@ -15,13 +15,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear_for_nk,
)
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
@@ -58,6 +67,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
@@ -90,6 +102,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(
kFp8DynamicTokenSym, True
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
}
@@ -100,6 +124,15 @@ class RMSNormQuantPattern:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
# groupwise FP8 linear uses col major scales if deepgemm and cutlass
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
self.model_dtype,
config.model_config.hf_config.intermediate_size,
config.model_config.hf_config.hidden_size,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
@@ -108,7 +141,9 @@ class RMSNormQuantPattern:
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
self.quant_matcher = MatcherQuantFP8(
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
@@ -218,6 +253,120 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
)
# result, residual, scale
return at[1], at[3], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(input: torch.Tensor, weight: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
)
# result, scale
return at[1], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
@@ -340,6 +489,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
@@ -366,9 +534,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
def uuid(self) -> Any:
return self.hash_source(
self,
RMSNormGroupQuantPattern,
RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern,
FusedAddRMSNormGroupQuantPattern,
)

View File

@@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
@@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
@@ -224,12 +230,20 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
def __init__(
self,
quant_key: QuantKey,
enabled: bool | None = None,
use_col_major_scales: bool = False,
use_e8m0: bool = False,
):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.use_col_major_scales = use_col_major_scales
self.use_e8m0 = use_e8m0
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
@@ -248,6 +262,27 @@ class MatcherQuantFP8(MatcherCustomOp):
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.group_shape.is_per_group():
assert scale is None
scale = self.make_scale(input, transposed=self.use_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.QUANT_OP,
input=input,
output_q=result,
output_s=scale,
group_size=self.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.use_e8m0,
)
return result, scale
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
@@ -269,7 +304,7 @@ class MatcherQuantFP8(MatcherCustomOp):
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
def make_scale(self, input: torch.Tensor, transposed: bool = False):
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
@@ -277,6 +312,11 @@ class MatcherQuantFP8(MatcherCustomOp):
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
if transposed:
scale_shape = tuple(reversed(scale_shape))
return torch.empty(
scale_shape, device=input.device, dtype=torch.float32
).permute(-1, -2)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)