[torch.compile] Disable recursive pre_grad_passes (#34092)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -257,7 +257,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
if use_aot:
|
||||
compile_kwargs["aot"] = True # type: ignore[assignment]
|
||||
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
# Inductor's pre-grad passes don't do anything for vLLM.
|
||||
# The pre-grad passes get run even on cache-hit and negatively impact
|
||||
# vllm cold compile times by O(1s)
|
||||
# Can remove this after the following issue gets fixed
|
||||
# https://github.com/pytorch/pytorch/issues/174502
|
||||
if envs.VLLM_ENABLE_PREGRAD_PASSES:
|
||||
ctx: Any = contextlib.nullcontext()
|
||||
else:
|
||||
ctx = patch(
|
||||
"torch._inductor.compile_fx._recursive_pre_grad_passes",
|
||||
lambda gm, _: gm,
|
||||
)
|
||||
with ctx:
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
Reference in New Issue
Block a user