[AOT compilation] support torch.compile inductor artifacts in VllmCompiledFunction (#25205)

Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
Author: dolpm
Date: 2026-01-20 11:45:59 -08:00
Committed by: GitHub
parent 193069d129
commit 7c5dedc247
8 changed files with 1169 additions and 113 deletions


@@ -16,9 +16,12 @@ import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
+from vllm.logger import init_logger
 from vllm.utils.hashing import safe_hash
+from vllm.utils.torch_utils import is_torch_equal_or_newer

+logger = init_logger(__name__)


 class CompilerInterface:
     """
@@ -230,12 +233,42 @@ class InductorStandaloneAdaptor(CompilerInterface):
         from torch._inductor import standalone_compile

-        compiled_graph = standalone_compile(
-            graph,
-            example_inputs,
-            dynamic_shapes=dynamic_shapes,
-            options={"config_patches": current_config},
-        )
+        supports_aot = is_torch_equal_or_newer("2.10.0.dev")
+        if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
+            logger.error(
+                "CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
+                "is enabled but PyTorch version does not support 'aot' "
+                "parameter in standalone_compile. This requires PyTorch "
+                "2.10.0+. Falling back to non-AOT mode."
+            )
+
+        compile_kwargs = {
+            "dynamic_shapes": dynamic_shapes,
+            "options": {
+                "config_patches": current_config,
+            },
+        }
+        use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
+        # only add 'aot' parameter if both supported and enabled...
+        # this will set bundled_autograd_cache
+        # https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
+        if use_aot:
+            compile_kwargs["aot"] = True  # type: ignore[assignment]
+
+        compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
+
+        if use_aot:
+            from torch._inductor.standalone_compile import AOTCompiledArtifact
+
+            assert isinstance(compiled_graph, AOTCompiledArtifact)
+            assert hasattr(compiled_graph, "serialize")
+            # just return the compiled graph and a key
+            # since we can serialize the bytes using to_bytes
+            # and reload it using the key when reading
+            return compiled_graph, None
+
         # Save the compiled artifact to disk in the specified path
         assert key is not None
         path = os.path.join(self.cache_dir, key)
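
The method now has two return shapes: the AOT path hands back the artifact itself with a None key and leaves serialization to the caller, while the non-AOT path keeps writing the artifact under cache_dir/key. A hedged sketch of a caller handling both shapes; the helper name handle_compile_result is hypothetical, not part of this commit:

    import os

    def handle_compile_result(compiled_graph, key, cache_dir, use_aot):
        if use_aot:
            # AOT path: (compiled_graph, None) was returned above. The
            # artifact carries its own serializer (the diff asserts that
            # a 'serialize' attribute exists), so persistence is deferred
            # to whoever holds the object.
            assert key is None
            return compiled_graph
        # Non-AOT path: the adaptor already saved the artifact to disk;
        # reconstruct the cache location from cache_dir and key.
        assert key is not None
        return os.path.join(cache_dir, key)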
@@ -619,7 +652,8 @@ def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:

 def set_functorch_config() -> None:
-    torch._functorch.config.bundled_autograd_cache = False
+    if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
+        torch._functorch.config.bundled_autograd_cache = False


 class EagerAdaptor(CompilerInterface):
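
With this change, bundled_autograd_cache is forced off only when the mega AOT artifact path is disabled; enabling the toggle leaves PyTorch's setting in place so the aot=True compile can bundle the autograd cache. A hedged usage sketch of how the toggle combines with the version gate (this assumes vllm.envs reads the variable from the process environment, as vLLM env flags normally do):

    import os

    os.environ["VLLM_USE_MEGA_AOT_ARTIFACT"] = "1"  # opt in before vLLM reads it

    import vllm.envs as envs
    from vllm.utils.torch_utils import is_torch_equal_or_newer

    # Mirrors the 'use_aot' computation in the hunk above: both the torch
    # version and the env toggle must allow the AOT path.
    use_aot = is_torch_equal_or_newer("2.10.0.dev") and envs.VLLM_USE_MEGA_AOT_ARTIFACT
    print("mega AOT artifact path active:", bool(use_aot))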