Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.config.utils import config
|
||||
from vllm.config.utils import config, handle_deprecated
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
@@ -105,18 +105,43 @@ class PassConfig:
|
||||
improper state.
|
||||
"""
|
||||
|
||||
# New flags
|
||||
fuse_norm_quant: bool = Field(default=None)
|
||||
"""Fuse the custom RMSNorm + quant ops."""
|
||||
fuse_act_quant: bool = Field(default=None)
|
||||
"""Fuse the custom SiluMul + quant ops."""
|
||||
fuse_attn_quant: bool = Field(default=None)
|
||||
"""Fuse the custom attention + quant ops."""
|
||||
eliminate_noops: bool = Field(default=None)
|
||||
"""Eliminate no-op ops."""
|
||||
enable_sp: bool = Field(default=None)
|
||||
"""Enable sequence parallelism."""
|
||||
fuse_gemm_comms: bool = Field(default=None)
|
||||
"""Enable async TP."""
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
|
||||
# Deprecated flags
|
||||
enable_fusion: bool = Field(default=None)
|
||||
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
|
||||
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
|
||||
"""
|
||||
enable_attn_fusion: bool = Field(default=None)
|
||||
"""Whether to enable the custom attention+quant fusion pass."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_noop: bool = Field(default=None)
|
||||
"""Whether to enable the custom no-op elimination pass."""
|
||||
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_sequence_parallelism: bool = Field(default=None)
|
||||
"""Whether to enable sequence parallelism."""
|
||||
"""Deprecated in: v0.12.0. Use enable_sp instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_async_tp: bool = Field(default=None)
|
||||
"""Whether to enable async TP."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_fi_allreduce_fusion: bool = Field(default=None)
|
||||
"""Whether to enable flashinfer allreduce fusion."""
|
||||
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
vllm should use flashinfer fused allreduce. Specified as a
|
||||
@@ -136,7 +161,7 @@ class PassConfig:
|
||||
},
|
||||
}, where key is the device capability"""
|
||||
enable_qk_norm_rope_fusion: bool = False
|
||||
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
|
||||
"""Enable fused Q/K RMSNorm + RoPE pass."""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
@@ -174,6 +199,13 @@ class PassConfig:
|
||||
return InductorPass.hash_dict(asdict(self))
|
||||
|
||||
@field_validator(
|
||||
"fuse_norm_quant",
|
||||
"fuse_act_quant",
|
||||
"fuse_attn_quant",
|
||||
"eliminate_noops",
|
||||
"enable_sp",
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
"enable_fusion",
|
||||
"enable_attn_fusion",
|
||||
"enable_noop",
|
||||
@@ -190,18 +222,71 @@ class PassConfig:
|
||||
return handler(value)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.enable_noop:
|
||||
if self.enable_fusion:
|
||||
# Handle deprecation and defaults
|
||||
|
||||
# Map old flags to new flags and issue warnings
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_fusion",
|
||||
["fuse_norm_quant", "fuse_act_quant"],
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_attn_fusion",
|
||||
"fuse_attn_quant",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_sequence_parallelism",
|
||||
"enable_sp",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_async_tp",
|
||||
"fuse_gemm_comms",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_fi_allreduce_fusion",
|
||||
"fuse_allreduce_rms",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_noop",
|
||||
"eliminate_noops",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
# Force old flags to None to ensure they are not used
|
||||
self.enable_fusion = None
|
||||
self.enable_attn_fusion = None
|
||||
self.enable_noop = None
|
||||
self.enable_sequence_parallelism = None
|
||||
self.enable_async_tp = None
|
||||
self.enable_fi_allreduce_fusion = None
|
||||
|
||||
if not self.eliminate_noops:
|
||||
if self.fuse_norm_quant or self.fuse_act_quant:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm/SiluMul + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_attn_fusion:
|
||||
if self.fuse_attn_quant:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Attention + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_fi_allreduce_fusion:
|
||||
if self.fuse_allreduce_rms:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||
@@ -873,7 +958,7 @@ class CompilationConfig:
|
||||
self.set_splitting_ops_for_inductor_graph_partition()
|
||||
return
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
# here use_inductor_graph_partition is False
|
||||
self.set_splitting_ops_for_attn_fusion()
|
||||
return
|
||||
@@ -915,12 +1000,12 @@ class CompilationConfig:
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
|
||||
def set_splitting_ops_for_attn_fusion(self):
|
||||
assert self.pass_config.enable_attn_fusion
|
||||
assert self.pass_config.fuse_attn_quant
|
||||
if self.splitting_ops is None:
|
||||
self.splitting_ops = []
|
||||
if self.cudagraph_mode.has_piecewise_cudagraphs():
|
||||
logger.warning_once(
|
||||
"enable_attn_fusion is incompatible with piecewise "
|
||||
"fuse_attn_quant is incompatible with piecewise "
|
||||
"cudagraph when use_inductor_graph_partition is off. "
|
||||
"In this case, splitting_ops will be set to empty "
|
||||
"list, and cudagraph_mode will be set to FULL. "
|
||||
@@ -931,8 +1016,7 @@ class CompilationConfig:
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
assert not self.splitting_ops_contain_attention(), (
|
||||
"attention ops should not be in splitting_ops "
|
||||
"when enable_attn_fusion is True"
|
||||
"attention ops should not be in splitting_ops when fuse_attn_quant is True"
|
||||
)
|
||||
|
||||
def splitting_ops_contain_attention(self) -> bool:
|
||||
@@ -1008,7 +1092,7 @@ class CompilationConfig:
|
||||
self, uniform_decode_query_len: int, tensor_parallel_size: int
|
||||
):
|
||||
multiple_of = uniform_decode_query_len
|
||||
if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
|
||||
if tensor_parallel_size > 1 and self.pass_config.enable_sp:
|
||||
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
|
||||
if (
|
||||
multiple_of % uniform_decode_query_len != 0
|
||||
|
||||
Reference in New Issue
Block a user