diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 00fb95921..2b667344f 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
             num_submods = len(submod_names)
             num_artifacts = standalone_compile_artifacts.num_artifacts()
 
-            logger.info(
-                "reconstructing serializable fn from standalone compile "
-                "artifacts. num_artifacts=%d num_submods=%d",
-                num_artifacts,
-                num_submods,
-            )
-
             with functorch_ctx:
                 fn = reconstruct_serializable_fn_from_mega_artifact(
                     state=state,
@@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
                 )
 
             logger.info(
-                "reconstructed serializable fn from standalone compile artifacts"
+                "reconstructed serializable fn from standalone compile "
+                "artifacts. num_artifacts=%d num_submods=%d",
+                num_artifacts,
+                num_submods,
             )
 
             return fn
diff --git a/vllm/envs.py b/vllm/envs.py
index caa2fb38a..d6240df36 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -296,6 +296,22 @@ def use_aot_compile() -> bool:
     )
 
 
+def use_mega_aot_artifact() -> bool:
+    """Return whether the mega AOT artifact path is enabled.
+
+    ``VLLM_USE_MEGA_AOT_ARTIFACT`` takes precedence when set; only the exact
+    value ``"1"`` enables it. When unset, it defaults to on if torch is
+    2.12.0.dev or newer and ``use_aot_compile()`` returns true, otherwise off.
+    """
+    from vllm.utils.torch_utils import is_torch_equal_or_newer
+
+    default_value = (
+        "1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0"
+    )
+
+    return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1"
+
+
 def env_with_choices(
     env_name: str,
     default: str | None,
@@ -616,10 +632,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Enable loading compiled models directly from cached standalone compile artifacts
     # without re-splitting graph modules. This reduces overhead during model
     # loading by using reconstruct_serializable_fn_from_mega_artifact.
-    "VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
-        "VLLM_USE_MEGA_AOT_ARTIFACT", "0"
-    )
-    == "1",
+    # Defaults to 1 when torch >= 2.12.0.dev and AOT compile is enabled, else 0.
+    "VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),