[Bugfix] vLLM should check Inductor config for compile cache enablement status (#27637)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
gmagogsfm
2025-11-05 09:22:44 -08:00
committed by GitHub
parent 752ddeacaa
commit 002b07c4b2
2 changed files with 26 additions and 7 deletions

View File

@@ -163,6 +163,23 @@ def get_inductor_factors() -> list[Any]:
return factors
def is_compile_cache_enabled(
vllm_additional_inductor_config: dict[str, Any],
) -> bool:
vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
"force_disable_caches", False
)
# TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
# with torch.compiler.config.force_disable_caches when minimum PyTorch
# version reaches 2.10
return (
not envs.VLLM_DISABLE_COMPILE_CACHE
and not torch._inductor.config.force_disable_caches
and not vllm_inductor_config_disable_cache
)
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -222,7 +239,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
# Save the compiled artifact to disk in the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
if not envs.VLLM_DISABLE_COMPILE_CACHE:
if is_compile_cache_enabled(compiler_config):
compiled_graph.save(path=path, format=self.save_format)
compilation_counter.num_compiled_artifacts_saved += 1
return compiled_graph, (key, path)
@@ -472,10 +490,8 @@ class InductorAdaptor(CompilerInterface):
config_patches=current_config,
)
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
# compilation cache. So turn off the checks if we disable the
# compilation cache.
if not envs.VLLM_DISABLE_COMPILE_CACHE:
# Turn off the checks if we disable the compilation cache.
if is_compile_cache_enabled(compiler_config):
if hash_str is None:
raise RuntimeError(
"vLLM failed to compile the model. The most "