[Compile] Conditional compilation. Introduce compile_ranges (#24252)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
Ilya Markov
2025-12-05 19:17:32 +01:00
committed by GitHub
parent 66e674cdd5
commit 4e26d3b09e
15 changed files with 582 additions and 268 deletions

View File

@@ -24,7 +24,11 @@ if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
@@ -70,13 +74,13 @@ class PostGradPassManager(CustomGraphPass):
def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape
compile_range = get_pass_context().compile_range
for pass_ in self.passes:
if pass_.is_applicable(shape):
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
VllmInductorPass.dump_prefix += 1
else:
logger.debug("Skipping %s with shape %s", pass_, shape)
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
@@ -133,4 +137,8 @@ class PostGradPassManager(CustomGraphPass):
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
return InductorPass.hash_dict(state)