diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 1d5adb185..c00486af6 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -257,7 +257,20 @@ class InductorStandaloneAdaptor(CompilerInterface): if use_aot: compile_kwargs["aot"] = True # type: ignore[assignment] - compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs) + # Inductor's pre-grad passes don't do anything for vLLM. + # The pre-grad passes get run even on cache-hit and negatively impact + # vllm cold compile times by O(1s) + # Can remove this after the following issue gets fixed + # https://github.com/pytorch/pytorch/issues/174502 + if envs.VLLM_ENABLE_PREGRAD_PASSES: + ctx: Any = contextlib.nullcontext() + else: + ctx = patch( + "torch._inductor.compile_fx._recursive_pre_grad_passes", + lambda gm, _: gm, + ) + with ctx: + compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs) if use_aot: from torch._inductor.standalone_compile import AOTCompiledArtifact diff --git a/vllm/envs.py b/vllm/envs.py index 314f42758..039b3239c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -132,6 +132,7 @@ if TYPE_CHECKING: VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 VLLM_USE_STANDALONE_COMPILE: bool = True + VLLM_ENABLE_PREGRAD_PASSES: bool = False VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 @@ -568,6 +569,15 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_STANDALONE_COMPILE", "1" ) == "1", + # Inductor's pre-grad passes don't do anything for vLLM. + # The pre-grad passes get run even on cache-hit and negatively impact + # vllm cold compile times by O(1s) + # Can remove this after the following issue gets fixed + # https://github.com/pytorch/pytorch/issues/174502 + "VLLM_ENABLE_PREGRAD_PASSES": lambda: os.environ.get( + "VLLM_ENABLE_PREGRAD_PASSES", "0" + ) + == "1", # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get(