[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (#23091), fix test (#24376), and prep for custom op matching (#24604) (#24542)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Luka Govedič
Date: 2025-09-22 15:30:05 -04:00
Committed by: GitHub
parent 8d0ee5a564
commit d5e0fca264
24 changed files with 404 additions and 461 deletions

vllm/compilation/pass_manager.py

@@ -1,15 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
 
 from torch import fx as fx
 
+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import set_env_var
+
+from .post_cleanup import PostCleanupPass
+from .vllm_inductor_pass import VllmInductorPass
 
 if current_platform.is_cuda_alike():
     from .activation_quant_fusion import ActivationQuantFusionPass
-    from .fusion import FusionPass
+    from .fusion import RMSNormQuantFusionPass
     from .fusion_attn import AttnFusionPass
 
 if current_platform.is_cuda():
@@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
 from .noop_elimination import NoOpEliminationPass
 from .sequence_parallelism import SequenceParallelismPass
-from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
 
 
+def with_pattern_match_debug(fn):
+    """
+    Function decorator that turns on inductor pattern match debug
+    for the duration of the call.
+    Used to avoid logging builtin Inductor pattern matching.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
+            # optionally check rank here
+            with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
+                return fn(*args, **kwargs)
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
 class PostGradPassManager(CustomGraphPass):
     """
     The pass manager for post-grad passes.
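The new decorator scopes Inductor's pattern-match debugging to vLLM's own passes: `TORCHINDUCTOR_PATTERN_MATCH_DEBUG` (which Inductor's pattern matcher consults for verbose match logging) is only set while the wrapped call runs, so PyTorch's builtin pattern matching outside the pass manager stays quiet. A minimal self-contained sketch of the mechanism, with a stand-in for `vllm.utils.set_env_var` (the diff only shows it used as a context manager, so its exact signature is an assumption) and the `vllm.envs` lookup inlined as a plain `os.environ.get`:

```python
import functools
import os
from contextlib import contextmanager


@contextmanager
def set_env_var(key: str, value: str):
    # Stand-in for vllm.utils.set_env_var: set the variable for the duration
    # of the with-block, then restore whatever was there before.
    old = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = old


def with_pattern_match_debug(fn):
    # Same shape as the decorator in the diff above.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if (debug_val := os.environ.get("VLLM_PATTERN_MATCH_DEBUG")) is not None:
            with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
                return fn(*args, **kwargs)
        return fn(*args, **kwargs)

    return wrapper


@with_pattern_match_debug
def run_passes():
    print(os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG"))


os.environ["VLLM_PATTERN_MATCH_DEBUG"] = "my_pattern_name"
run_passes()  # prints "my_pattern_name"
assert "TORCHINDUCTOR_PATTERN_MATCH_DEBUG" not in os.environ  # restored on exit
```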
@@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass):
     """
 
     def __init__(self):
-        self.passes: list[VllmInductorPass] = []
+        self.passes: list[InductorPass] = []
 
+    @with_pattern_match_debug
     def __call__(self, graph: fx.Graph):
+        VllmInductorPass.dump_prefix = 0  # reset dump index
+
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
             if pass_.is_applicable_for_shape(shape):
                 pass_(graph)
+                VllmInductorPass.dump_prefix += 1
+
+        # post-cleanup goes before fix_functionalization
+        # because it requires a functional graph
+        self.post_cleanup(graph)
+        VllmInductorPass.dump_prefix += 1
 
         # always run fix_functionalization last
         self.fix_functionalization(graph)
 
+        VllmInductorPass.dump_prefix = None  # Cleanup index
+
     def configure(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
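`__call__` gains two of the commit's debug utilities: `VllmInductorPass.dump_prefix` numbers per-pass graph dumps in execution order and is reset to `None` once the manager finishes, and each registered pass is still gated on the shape being compiled. A sketch of that gating contract, assuming `runtime_shape` is a concrete integer for shape-specialized compilations and `None` for the general dynamic-shape graph (the pass name and predicate below are hypothetical; real passes derive from `InductorPass`/`VllmInductorPass`):

```python
from torch import fx


class EvenBatchOnlyPass:
    # Hypothetical pass: only worth running when the graph is compiled
    # for a concrete batch size divisible by 2.
    def is_applicable_for_shape(self, shape: int | None) -> bool:
        return shape is not None and shape % 2 == 0

    def __call__(self, graph: fx.Graph) -> None:
        # Graph rewriting would go here; the manager calls this only
        # after is_applicable_for_shape returned True.
        pass
```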
@@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass):
         if self.pass_config.enable_async_tp:
             self.passes += [AsyncTPPass(config)]
 
-        if self.pass_config.enable_fi_allreduce_fusion:
-            self.passes += [AllReduceFusionPass(config)]
-
         if self.pass_config.enable_fusion:
-            self.passes += [FusionPass.instance(config)]
+            self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]
 
         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
 
+        if self.pass_config.enable_fi_allreduce_fusion:
+            self.passes += [AllReduceFusionPass(config)]
+
+        # needs a functional graph
+        self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)
 
     def add(self, pass_: InductorPass):
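`configure` is where the flag-to-pass mapping lives: `enable_fusion` now registers `RMSNormQuantFusionPass` (replacing the old singleton-style `FusionPass.instance(config)`) alongside `ActivationQuantFusionPass`, the FlashInfer allreduce fusion moves after attention fusion, and `PostCleanupPass` plus `FixFunctionalizationPass` always run regardless of flags. A hedged sketch of turning these flags on from user code; the `enable_*` field names appear verbatim in the diff above, while the `PassConfig`/`CompilationConfig` import paths and constructor style are assumptions:

```python
from vllm import LLM
from vllm.config import CompilationConfig, PassConfig  # import path assumed

llm = LLM(
    model="facebook/opt-125m",
    compilation_config=CompilationConfig(
        pass_config=PassConfig(
            enable_fusion=True,               # RMSNormQuantFusionPass + ActivationQuantFusionPass
            enable_attn_fusion=True,          # AttnFusionPass
            enable_fi_allreduce_fusion=True,  # AllReduceFusionPass, registered last
        ),
    ),
)
```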