[torch.compile] caching of config fields should be opt-out by default (#26468)
Signed-off-by: vnadathur <glvikramn@gmail.com> Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com> Signed-off-by: Srreyansh Sethi <srreyansh.sethi@gmail.com> Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com> Co-authored-by: WorldExplored <srreyansh.sethi@gmail.com> Co-authored-by: Srreyansh Sethi <107075589+worldexplored@users.noreply.github.com> Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -4,12 +4,14 @@
|
||||
import ast
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
import operator
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -23,7 +25,9 @@ from vllm.compilation.partition_rules import (
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
@@ -580,35 +584,47 @@ class VllmBackend:
|
||||
def __call__(
|
||||
self, graph: fx.GraphModule, example_inputs
|
||||
) -> VllmSerializableFunction:
|
||||
from .caching import _compute_code_hash, compilation_config_hash_factors
|
||||
|
||||
vllm_config = self.vllm_config
|
||||
# Minimal hashing here with existing utilities, reused below.
|
||||
|
||||
env_factors = envs.compile_factors()
|
||||
env_hash = hash_factors(env_factors)
|
||||
# Compute config/compiler/code hashes once and reuse
|
||||
config_hash = vllm_config.compute_hash()
|
||||
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
|
||||
forward_code_files = list(sorted(self.compilation_config.traced_files))
|
||||
|
||||
logger.debug(
|
||||
"Traced files (to be considered for compilation cache):\n%s",
|
||||
lazy(lambda: "\n".join(forward_code_files)),
|
||||
)
|
||||
hash_content = []
|
||||
for filepath in forward_code_files:
|
||||
hash_content.append(filepath)
|
||||
if filepath == "<string>":
|
||||
# This means the function was dynamically generated, with
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
try:
|
||||
with open(filepath) as f:
|
||||
hash_content.append(f.read())
|
||||
except Exception:
|
||||
logger.warning("Failed to read file %s", filepath)
|
||||
continue
|
||||
code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
|
||||
# Clear after consumption
|
||||
self.compilation_config.traced_files.clear()
|
||||
if not self.compilation_config.cache_dir:
|
||||
# no provided cache dir, generate one based on the known factors
|
||||
# that affects the compilation. if none of the factors change,
|
||||
# the cache dir will be the same so that we can reuse the compiled
|
||||
# graph.
|
||||
|
||||
factors = compilation_config_hash_factors(vllm_config)
|
||||
# 2. factors come from the code files that are traced by Dynamo (
|
||||
# it mainly summarizes how the model is used in forward pass)
|
||||
code_hash = _compute_code_hash(self.compilation_config.traced_files)
|
||||
self.compilation_config.traced_files.clear()
|
||||
factors.append(code_hash)
|
||||
|
||||
# 3. compiler hash
|
||||
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
|
||||
factors.append(compiler_hash)
|
||||
|
||||
# combine all factors to generate the cache dir
|
||||
hash_key = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
|
||||
factors = [env_hash, config_hash, code_hash, compiler_hash]
|
||||
# Use SHA-256 for cache key hashing to be consistent across
|
||||
# compute_hash functions. Truncate for a short cache dir name.
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_compile_cache",
|
||||
hash_key,
|
||||
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
|
||||
)
|
||||
self.compilation_config.cache_dir = cache_dir
|
||||
|
||||
@@ -621,6 +637,7 @@ class VllmBackend:
|
||||
os.makedirs(local_cache_dir, exist_ok=True)
|
||||
self.compilation_config.local_cache_dir = local_cache_dir
|
||||
|
||||
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
|
||||
disable_cache = not is_compile_cache_enabled(
|
||||
self.compilation_config.inductor_compile_config
|
||||
)
|
||||
@@ -638,6 +655,50 @@ class VllmBackend:
|
||||
local_cache_dir, disable_cache, self.prefix
|
||||
)
|
||||
|
||||
# Reuses existing cache key
|
||||
|
||||
logger.debug(
|
||||
"torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
|
||||
env_hash,
|
||||
config_hash,
|
||||
compiler_hash,
|
||||
code_hash,
|
||||
local_cache_dir,
|
||||
)
|
||||
|
||||
# Persist and log only hash-relevant factors together.
|
||||
try:
|
||||
logger.debug(
|
||||
"Compile env factors (raw):\n%s\nVllm config hash: %s",
|
||||
lazy(partial(pprint.pformat, env_factors, width=120)),
|
||||
config_hash,
|
||||
)
|
||||
meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
|
||||
if not os.path.exists(meta_path):
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"env": env_factors, # raw factors used for env_hash
|
||||
"config_hash": config_hash,
|
||||
"code_hash": code_hash,
|
||||
"compiler_hash": compiler_hash,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort only; metadata write failures are non-fatal.
|
||||
logger.warning(
|
||||
(
|
||||
"Could not write compile cache metadata at %s; continuing without "
|
||||
"metadata. Compiled cache remains valid; diagnostics may be "
|
||||
"limited."
|
||||
),
|
||||
local_cache_dir,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# when dynamo calls the backend, it means the bytecode
|
||||
# transform and analysis are done
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
|
||||
@@ -127,7 +127,7 @@ class PostGradPassManager(CustomGraphPass):
|
||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||
dependent passes and the pass config. See InductorPass for more info.
|
||||
"""
|
||||
state = {"pass_config": self.pass_config.uuid(), "passes": []}
|
||||
state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
|
||||
for pass_ in self.passes:
|
||||
state["passes"].append(pass_.uuid())
|
||||
state["passes"].append(self.fix_functionalization.uuid())
|
||||
|
||||
Reference in New Issue
Block a user