[Bugfix] Fix fusion for VL models (#30244)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Author: ElizaWszola
Date: 2025-12-14 14:22:37 +01:00
Committed by: GitHub
Parent: 48b8456ff9
Commit: 994acec0cc
4 changed files with 143 additions and 72 deletions

vllm/compilation/fusion.py

@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kNvfp4Quant,
     kStaticTensorScale,
 )
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    cutlass_block_fp8_supported,
-)
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import (
-    is_deep_gemm_e8m0_used,
-    should_use_deepgemm_for_fp8_linear_for_nk,
-)
 
 from .inductor_pass import enable_fake_mode
-from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
+from .matcher_utils import (
+    MatcherFusedAddRMSNorm,
+    MatcherQuantFP8,
+    MatcherRMSNorm,
+)
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
 
 
 class RMSNormQuantPattern:
-    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
+    def __init__(
+        self,
+        epsilon: float,
+        key: FusedRMSQuantKey,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
+    ):
         self.epsilon = epsilon
         self.quant_dtype = key.quant.dtype
 
         config = get_current_vllm_config()
         self.model_dtype = config.model_config.dtype if config.model_config else None
-
-        # groupwise FP8 linear uses col major scales if deepgemm and cutlass
-        using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
-            self.model_dtype,
-            config.model_config.hf_config.intermediate_size,
-            config.model_config.hf_config.hidden_size,
-        )
-        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
-        use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
 
         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]
@@ -142,7 +136,7 @@ class RMSNormQuantPattern:
             MatcherRMSNorm(epsilon)
             if not key.fused_add
             else MatcherFusedAddRMSNorm(epsilon)
         )
         self.quant_matcher = MatcherQuantFP8(
-            key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
+            key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
         )
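Taken together, the hunks above move layout detection out of RMSNormQuantPattern: instead of deriving use_col_major_scales/use_e8m0 from the loaded model config (the deleted lookup reached into config.model_config.hf_config.intermediate_size, and VL checkpoints typically keep those fields on a nested text config rather than the top-level hf_config), the caller now states the layout explicitly. A minimal sketch of the new calling convention; the key construction mirrors the diff, but the concrete values and the quant_utils import path are assumptions:

import torch

# Assumed import path, following the imports visible in this file's diff.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
)

# FusedRMSQuantKey and RMSNormQuantPattern are defined in the file being
# patched here; the values below are hypothetical.
group_shape = GroupShape(1, 128)
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
    fused_add=False,
    quant=QuantKey(dtype=torch.float8_e4m3fn, scale=scale, symmetric=True),
)

# The pattern no longer inspects the vLLM config for these two flags:
pattern = RMSNormQuantPattern(
    epsilon=1e-6,
    key=key,
    has_col_major_scales=True,  # e.g. block-FP8 kernels wanting col-major scales
    is_e8m0=False,              # UE8M0 scale exponent encoding off
)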
@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
         quant_dtype: torch.dtype,
         group_shape: GroupShape,
         symmetric=True,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         scale = ScaleDesc(torch.float32, False, group_shape)
         key = FusedRMSQuantKey(
@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
             quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
         )
         self.group_shape = group_shape
-        super().__init__(epsilon, key)
+        self.has_col_major_scales = has_col_major_scales
+        self.is_e8m0 = is_e8m0
+        super().__init__(
+            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
+        )
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
         def replacement(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
             input = input.to(dtype=self.model_dtype)
             result = torch.empty_like(input, dtype=self.quant_dtype)
-            scale = self.quant_matcher.make_scale(
-                input, transposed=self.quant_matcher.use_col_major_scales
-            )
+            scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
                 scale_ub=None,
                 residual=residual,
                 group_size=self.group_shape[1],
-                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+                is_scale_transposed=self.has_col_major_scales,
             )
             # result, residual, scale
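For readers unfamiliar with the flag being threaded through here: with group_shape=(1, 128), quantizing an (M, K) activation produces one FP32 scale per 128-wide column group, i.e. a logical (M, K // 128) scale tensor. has_col_major_scales only changes the memory layout of that tensor; block-FP8 GEMM backends such as DeepGEMM and CUTLASS want it contiguous along the token dimension. A rough, standalone illustration (not vLLM's make_scale implementation):

import torch

M, K, group = 256, 4096, 128
n_groups = K // group  # one scale per (1, 128) group

# Row-major layout: the scales for one token's groups are contiguous.
row_major = torch.empty(M, n_groups, dtype=torch.float32)

# "Column-major" layout: same logical (M, n_groups) tensor, but the
# token dimension is the contiguous one.
col_major = torch.empty(n_groups, M, dtype=torch.float32).t()

assert row_major.shape == col_major.shape        # same logical shape
assert row_major.stride() != col_major.stride()  # different memory layout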
@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
         quant_dtype: torch.dtype,
         group_shape: GroupShape,
         symmetric=True,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         scale = ScaleDesc(torch.float32, False, group_shape)
         key = FusedRMSQuantKey(
@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
             quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
         )
         self.group_shape = group_shape
-        super().__init__(epsilon, key)
+        super().__init__(
+            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
+        )
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor):
@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
             input = input.to(dtype=self.model_dtype)
             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(
-                input, transposed=self.quant_matcher.use_col_major_scales
+                input, transposed=self.quant_matcher.has_col_major_scales
             )
             at = auto_functionalized(
                 self.FUSED_OP,
@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
                 scale_ub=None,
                 residual=None,
                 group_size=self.group_shape[1],
-                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+                is_scale_transposed=self.quant_matcher.has_col_major_scales,
             )
             # result, scale
@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
         # Make sure fused add patterns are before simple rms norm,
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
-            # Fuse fused_add_rms_norm + fp8 group quant
-            # Only register group quant patterns on CUDA where the C++ op exists
-            if current_platform.is_cuda():
-                FusedAddRMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-                ).register(self.patterns)
-
-                # Fuse rms_norm + fp8 group quant
-                RMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-                ).register(self.patterns)
-
-                FusedAddRMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-                ).register(self.patterns)
-
-                # Fuse rms_norm + fp8 group quant
-                RMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-                ).register(self.patterns)
-
             # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
             # Fuse rms_norm + dynamic per-token fp8 quant
             RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
 
+            # Only register group quant patterns on CUDA where the C++ op exists
+            if current_platform.is_cuda():
+                for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
+                    for has_col_major_scales in [True, False]:
+                        for is_e8m0 in [True, False]:
+                            # Fuse fused_add_rms_norm + fp8 group quant
+                            FusedAddRMSNormGroupQuantPattern(
+                                epsilon,
+                                FP8_DTYPE,
+                                group_shape=group_shape,
+                                has_col_major_scales=has_col_major_scales,
+                                is_e8m0=is_e8m0,
+                            ).register(self.patterns)
+
+                            # Fuse rms_norm + fp8 group quant
+                            RMSNormGroupQuantPattern(
+                                epsilon,
+                                FP8_DTYPE,
+                                group_shape=group_shape,
+                                has_col_major_scales=has_col_major_scales,
+                                is_e8m0=is_e8m0,
+                            ).register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     @VllmInductorPass.time_and_log
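This replacement block is the substance of the fix. The old code (deleted above) picked a single scale layout and exponent encoding per pass, based on a config lookup that could be wrong, or could crash outright, for VL models; the new code registers a pattern for every (group_shape, has_col_major_scales, is_e8m0) combination and lets the inductor pattern matcher fire on whichever variant the traced graph actually contains. Unmatched variants simply never match. An equivalent, slightly more compact way to write the same enumeration (an illustrative rewrite reusing the names from the hunk above, not code from the commit):

from itertools import product

# 2 group shapes x 2 scale layouts x 2 exponent encodings
# = 8 variants per pattern class, per epsilon.
if current_platform.is_cuda():
    for group_shape, has_col_major_scales, is_e8m0 in product(
        [GroupShape(1, 128), GroupShape(1, 64)], [True, False], [True, False]
    ):
        for pattern_cls in (
            FusedAddRMSNormGroupQuantPattern,
            RMSNormGroupQuantPattern,
        ):
            pattern_cls(
                epsilon,
                FP8_DTYPE,
                group_shape=group_shape,
                has_col_major_scales=has_col_major_scales,
                is_e8m0=is_e8m0,
            ).register(self.patterns)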

vllm/compilation/matcher_utils.py

@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
         self,
         quant_key: QuantKey,
         enabled: bool | None = None,
-        use_col_major_scales: bool = False,
-        use_e8m0: bool = False,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         if enabled is None:
             enabled = QuantFP8.enabled()
         super().__init__(enabled)
 
         self.quant_key = quant_key
-        self.use_col_major_scales = use_col_major_scales
-        self.use_e8m0 = use_e8m0
         assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]
+        self.has_col_major_scales = has_col_major_scales
+        self.is_e8m0 = is_e8m0
 
         assert quant_key.dtype == current_platform.fp8_dtype(), (
             "Only QuantFP8 supported by MatcherQuantFP8"
         )
         assert quant_key.scale2 is None
-        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
+        self.quant_fp8 = QuantFP8(
+            quant_key.scale.static,
+            quant_key.scale.group_shape,
+            column_major_scales=has_col_major_scales,
+            use_ue8m0=is_e8m0,
+        )
 
     def forward_custom(
         self,
@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
         if self.quant_key.scale.group_shape.is_per_group():
             assert scale is None
-            scale = self.make_scale(input, transposed=self.use_col_major_scales)
+            scale = self.make_scale(input, transposed=self.has_col_major_scales)
 
         finfo = torch.finfo(self.quant_key.dtype)
         fp8_min = finfo.min
@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
             eps=1e-10,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
-            scale_ue8m0=self.use_e8m0,
+            scale_ue8m0=self.is_e8m0,
         )
         return result, scale
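The matcher-side change mirrors the fusion-pass side: MatcherQuantFP8 now receives the layout flags at construction and forwards them into the QuantFP8 reference op, so the pattern it traces produces scales with the same layout and encoding as the graph being matched. A sketch of what the constructor now builds internally; the flag names come from this diff, while the surrounding values and the QuantFP8 import path are assumptions:

# Assumed import paths for the QuantFP8 reference op and GroupShape.
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Hypothetical dynamic per-group FP8 quant matching GroupShape(1, 128):
quant_fp8 = QuantFP8(
    False,                     # static=False: dynamic scales
    GroupShape(1, 128),
    column_major_scales=True,  # <- has_col_major_scales
    use_ue8m0=False,           # <- is_e8m0
)

# Calling quant_fp8 on an (M, K) activation returns the FP8 tensor plus a
# per-group scale tensor whose layout and exponent encoding now follow the
# flags above, instead of always being row-major/FP32-exponent as before.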