[compile] Enable mega aot artifact for torch 2.12+. (#37198)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
@@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
num_submods = len(submod_names)
|
||||
num_artifacts = standalone_compile_artifacts.num_artifacts()
|
||||
|
||||
logger.info(
|
||||
"reconstructing serializable fn from standalone compile "
|
||||
"artifacts. num_artifacts=%d num_submods=%d",
|
||||
num_artifacts,
|
||||
num_submods,
|
||||
)
|
||||
|
||||
with functorch_ctx:
|
||||
fn = reconstruct_serializable_fn_from_mega_artifact(
|
||||
state=state,
|
||||
@@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"reconstructed serializable fn from standalone compile artifacts"
|
||||
"reconstructed serializable fn from standalone compile "
|
||||
"artifacts. num_artifacts=%d num_submods=%d",
|
||||
num_artifacts,
|
||||
num_submods,
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
15
vllm/envs.py
15
vllm/envs.py
@@ -296,6 +296,16 @@ def use_aot_compile() -> bool:
|
||||
)
|
||||
|
||||
|
||||
def use_mega_aot_artifact():
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
default_value = (
|
||||
"1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0"
|
||||
)
|
||||
|
||||
return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1"
|
||||
|
||||
|
||||
def env_with_choices(
|
||||
env_name: str,
|
||||
default: str | None,
|
||||
@@ -616,10 +626,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Enable loading compiled models directly from cached standalone compile artifacts
|
||||
# without re-splitting graph modules. This reduces overhead during model
|
||||
# loading by using reconstruct_serializable_fn_from_mega_artifact.
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
|
||||
)
|
||||
== "1",
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
|
||||
# local rank of the process in the distributed setting, used to determine
|
||||
# the GPU device id
|
||||
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
|
||||
|
||||
Reference in New Issue
Block a user