[compile] Enable mega aot artifact for torch 2.12+. (#37198)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen
2026-03-16 17:05:51 -04:00
committed by GitHub
parent 2dccb38f73
commit e6ae4b1be1
2 changed files with 15 additions and 12 deletions

View File

@@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
num_submods = len(submod_names)
num_artifacts = standalone_compile_artifacts.num_artifacts()
logger.info(
"reconstructing serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
with functorch_ctx:
fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
@@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
)
logger.info(
"reconstructed serializable fn from standalone compile artifacts"
"reconstructed serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
return fn

View File

@@ -296,6 +296,16 @@ def use_aot_compile() -> bool:
)
def use_mega_aot_artifact():
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = (
"1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0"
)
return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1"
def env_with_choices(
env_name: str,
default: str | None,
@@ -616,10 +626,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enable loading compiled models directly from cached standalone compile artifacts
# without re-splitting graph modules. This reduces overhead during model
# loading by using reconstruct_serializable_fn_from_mega_artifact.
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
)
== "1",
"VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),