[torch.compile] Disable recursive pre_grad_passes (#34092)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-02-10 18:02:31 -05:00
committed by GitHub
parent 6f2f59f2b3
commit 341eed3d30
2 changed files with 24 additions and 1 deletions

View File

@@ -257,7 +257,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
if use_aot:
compile_kwargs["aot"] = True # type: ignore[assignment]
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
# Inductor's pre-grad passes don't do anything for vLLM.
# The pre-grad passes get run even on cache hits and negatively impact
# vLLM cold-compile times by O(1s).
# This workaround can be removed once the following issue is fixed:
# https://github.com/pytorch/pytorch/issues/174502
if envs.VLLM_ENABLE_PREGRAD_PASSES:
ctx: Any = contextlib.nullcontext()
else:
ctx = patch(
"torch._inductor.compile_fx._recursive_pre_grad_passes",
lambda gm, _: gm,
)
with ctx:
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:
from torch._inductor.standalone_compile import AOTCompiledArtifact