[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (#23091), fix test (#24376), and prep for custom op matching (#24604) (#24542)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2025-09-22 15:30:05 -04:00
committed by GitHub
parent 8d0ee5a564
commit d5e0fca264
24 changed files with 404 additions and 461 deletions

View File

@@ -15,7 +15,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass):
class SequenceParallelismPass(VllmPatternMatcherPass):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
@@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass):
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with sequence parallelism", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)