[Core] Rename PassConfig flags as per RFC #27995 (#29646)

Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Arpit Khandelwal
2025-12-02 22:38:55 -05:00
committed by GitHub
parent 506ed87e87
commit d7284a2604
22 changed files with 318 additions and 123 deletions

View File

@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.config.utils import config, handle_deprecated
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -105,18 +105,43 @@ class PassConfig:
improper state.
"""
# New flags
fuse_norm_quant: bool = Field(default=None)
"""Fuse the custom RMSNorm + quant ops."""
fuse_act_quant: bool = Field(default=None)
"""Fuse the custom SiluMul + quant ops."""
fuse_attn_quant: bool = Field(default=None)
"""Fuse the custom attention + quant ops."""
eliminate_noops: bool = Field(default=None)
"""Eliminate no-op ops."""
enable_sp: bool = Field(default=None)
"""Enable sequence parallelism."""
fuse_gemm_comms: bool = Field(default=None)
"""Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion: bool = Field(default=None)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion: bool = Field(default=None)
"""Whether to enable the custom attention+quant fusion pass."""
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop: bool = Field(default=None)
"""Whether to enable the custom no-op elimination pass."""
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism: bool = Field(default=None)
"""Whether to enable sequence parallelism."""
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp: bool = Field(default=None)
"""Whether to enable async TP."""
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Whether to enable flashinfer allreduce fusion."""
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
@@ -136,7 +161,7 @@ class PassConfig:
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
"""Enable fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@@ -174,6 +199,13 @@ class PassConfig:
return InductorPass.hash_dict(asdict(self))
@field_validator(
"fuse_norm_quant",
"fuse_act_quant",
"fuse_attn_quant",
"eliminate_noops",
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
@@ -190,18 +222,71 @@ class PassConfig:
return handler(value)
def __post_init__(self) -> None:
if not self.enable_noop:
if self.enable_fusion:
# Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated(
self,
"enable_fusion",
["fuse_norm_quant", "fuse_act_quant"],
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_attn_fusion",
"fuse_attn_quant",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_sequence_parallelism",
"enable_sp",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_async_tp",
"fuse_gemm_comms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_fi_allreduce_fusion",
"fuse_allreduce_rms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_noop",
"eliminate_noops",
"v0.13.0 or v1.0.0, whichever is sooner",
)
# Force old flags to None to ensure they are not used
self.enable_fusion = None
self.enable_attn_fusion = None
self.enable_noop = None
self.enable_sequence_parallelism = None
self.enable_async_tp = None
self.enable_fi_allreduce_fusion = None
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm/SiluMul + quant (fp8) fusion might not work"
)
if self.enable_attn_fusion:
if self.fuse_attn_quant:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"Attention + quant (fp8) fusion might not work"
)
if self.enable_fi_allreduce_fusion:
if self.fuse_allreduce_rms:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
@@ -873,7 +958,7 @@ class CompilationConfig:
self.set_splitting_ops_for_inductor_graph_partition()
return
if self.pass_config.enable_attn_fusion:
if self.pass_config.fuse_attn_quant:
# here use_inductor_graph_partition is False
self.set_splitting_ops_for_attn_fusion()
return
@@ -915,12 +1000,12 @@ class CompilationConfig:
self.splitting_ops = list(self._attention_ops)
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.enable_attn_fusion
assert self.pass_config.fuse_attn_quant
if self.splitting_ops is None:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"fuse_attn_quant is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
@@ -931,8 +1016,7 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
assert not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "
"when enable_attn_fusion is True"
"attention ops should not be in splitting_ops when fuse_attn_quant is True"
)
def splitting_ops_contain_attention(self) -> bool:
@@ -1008,7 +1092,7 @@ class CompilationConfig:
self, uniform_decode_query_len: int, tensor_parallel_size: int
):
multiple_of = uniform_decode_query_len
if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
if tensor_parallel_size > 1 and self.pass_config.enable_sp:
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
if (
multiple_of % uniform_decode_query_len != 0