[AOT compilation] support torch.compile inductor artifacts in VllmCompiledFunction (#25205)
Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
This commit is contained in:
@@ -16,9 +16,12 @@ import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
@@ -230,12 +233,42 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
compiled_graph = standalone_compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config},
|
||||
)
|
||||
supports_aot = is_torch_equal_or_newer("2.10.0.dev")
|
||||
|
||||
if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
logger.error(
|
||||
"CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
|
||||
"is enabled but PyTorch version does not support 'aot' "
|
||||
"parameter in standalone_compile. This requires PyTorch "
|
||||
"2.10.0+. Falling back to non-AOT mode."
|
||||
)
|
||||
|
||||
compile_kwargs = {
|
||||
"dynamic_shapes": dynamic_shapes,
|
||||
"options": {
|
||||
"config_patches": current_config,
|
||||
},
|
||||
}
|
||||
|
||||
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
|
||||
# only add 'aot' parameter if both supported and enabled...
|
||||
# this will set bundled_autograd_cache
|
||||
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
|
||||
if use_aot:
|
||||
compile_kwargs["aot"] = True # type: ignore[assignment]
|
||||
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
assert isinstance(compiled_graph, AOTCompiledArtifact)
|
||||
assert hasattr(compiled_graph, "serialize")
|
||||
# just return the compiled graph (the key slot is None here)
|
||||
# since we can serialize the bytes using serialize()
|
||||
# and reload it using the key when reading
|
||||
return compiled_graph, None
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
@@ -619,7 +652,8 @@ def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
|
||||
|
||||
|
||||
def set_functorch_config() -> None:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
|
||||
Reference in New Issue
Block a user