[torch.compile] fast inductor (#11108)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: youkaichao
Date: 2024-12-16 16:15:22 -08:00
Committed by: GitHub
Commit: 88a412ed3d (parent: c301616ed2)
3 changed files with 624 additions and 7 deletions

vllm/compilation/backends.py

@@ -1,6 +1,10 @@
import ast
import copy
import dataclasses
import os
import pprint
import time
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
@@ -21,6 +25,122 @@ from .pass_manager import PostGradPassManager
logger = init_logger(__name__)
class InductorHashCache:
"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str)
We use a list of tuples for readability.
In-memory format: a defaultdict of dicts, where the key is
runtime_shape, and the value is a dict mapping graph_index to hash_str.
The data is essentially `Dict[Optional[int], Dict[int, str]]`.
We don't use json here because json doesn't support int keys.
TODO: better off-the-shelf solution to serialize the data?
"""
def __init__(self, cache_dir: str, disabled: bool = False):
self.cache: defaultdict = defaultdict(dict)
self.disabled = disabled
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir,
"inductor_hash_cache.py")
if disabled:
return
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
if os.path.exists(self.cache_file_path):
with open(self.cache_file_path) as f:
self.deserialize(f.read())
def deserialize(self, data: str):
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
for runtime_shape, graph_index, hash_str in list_data:
self.cache[runtime_shape][graph_index] = hash_str
def serialize(self) -> str:
data = []
for runtime_shape, graph_index_to_hash_str in self.cache.items():
for graph_index, hash_str in graph_index_to_hash_str.items():
data.append((runtime_shape, graph_index, hash_str))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)
def save_to_file(self):
if self.disabled:
return
with open(self.cache_file_path, "w") as f:
f.write(self.serialize())
def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
if self.disabled:
return False
runtime_shape, graph_index = key
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]
def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
if self.disabled:
raise KeyError("cannot read from disabled cache")
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]
def __setitem__(self, key: Tuple[Optional[int], int], value: str):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
self.cache[runtime_shape][graph_index] = value
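To make the disk and in-memory formats above concrete, here is a minimal usage sketch; the cache directory, shapes, and hash strings are made-up placeholders, not real Inductor hashes.

# minimal sketch of the InductorHashCache round trip (placeholder values)
cache = InductorHashCache(cache_dir="/tmp/vllm_compile_cache")
cache[(None, 0)] = "fxhash_general_0"   # general (symbolic) shape, first piecewise graph
cache[(8, 0)] = "fxhash_bs8_0"          # static shape 8, first piecewise graph
assert (8, 0) in cache and cache[(8, 0)] == "fxhash_bs8_0"
print(cache.serialize())
# [(None, 0, 'fxhash_general_0'), (8, 0, 'fxhash_bs8_0')]
cache.save_to_file()  # writes <cache_dir>/inductor_hash_cache.py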
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have a shape environment to provide to Inductor, and
the Inductor code cache lookup will fail.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods were obtained by trial and error
until the cache lookup works.
"""
def __init__(self) -> None:
self.guards: List[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
def get_pruned_guards(self, *args, **kwargs):
return []
def produce_guards_expression(self, *args, **kwargs):
return ""
def wrap_inductor(graph,
example_inputs,
additional_inductor_config,
@@ -55,6 +175,90 @@ def wrap_inductor(graph,
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
cache_data = compilation_config.inductor_hash_cache
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before,
# so we can directly look up the compiled graph via its hash
hash_str = cache_data[(runtime_shape, graph_index)]
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
"Directly lookup the graph for shape %s from the cache",
str(runtime_shape)) # noqa
logger.debug(
"directly lookup the %s-th graph for shape %s via hash %s",
graph_index, str(runtime_shape), hash_str)
from torch._inductor.codecache import FxGraphCache
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()):
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove "
f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa
)
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the graph we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
else:
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
from torch._inductor.codecache import compiled_fx_graph_hash
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
# store the hash in the cache
nonlocal cache_data
cache_data[(runtime_shape, graph_index)] = out[0]
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug("store the %s-th graph for shape %s via hash %s",
graph_index, str(runtime_shape), out[0])
return out
def _check_can_cache(*args, **kwargs):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env():
return AlwaysHitShapeEnv()
with patch(# for hijacking the hash of the compiled graph
"torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash), \
patch(# for providing a dummy shape environment
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env), \
patch(# for forcing the graph to be cached
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache):
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)
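The hook above relies on a standard trick: keep a reference to the original function, patch the module attribute with a wrapper, and record the return value while passing it through. A stand-alone sketch with a stand-in function (not the real Inductor API):

# capture a function's return value via unittest.mock.patch (illustrative only)
from unittest.mock import patch

def expensive_hash(name):          # stand-in for compiled_fx_graph_hash
    return ("hash-of-" + name, None)

original = expensive_hash          # keep the real function before patching
captured = {}

def hijack(*args, **kwargs):
    out = original(*args, **kwargs)
    captured["hash"] = out[0]      # record the first element, like wrap_inductor
    return out                     # pass the result through unchanged

with patch(f"{__name__}.expensive_hash", hijack):
    assert expensive_hash("graph")[0] == "hash-of-graph"
assert captured["hash"] == "hash-of-graph"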
@@ -457,6 +661,9 @@ class PiecewiseBackend:
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
if not entry.use_cudagraph:

vllm/config.py

@@ -3,6 +3,7 @@ import copy
import enum
import hashlib
import json
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
@@ -162,6 +163,30 @@ class ModelConfig:
which allows no processors.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.model)
factors.append(self.dtype)
factors.append(self.quantization)
factors.append(self.quantization_param_path)
factors.append(self.revision)
factors.append(self.code_revision)
factors.append(self.trust_remote_code)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
return hashlib.sha256(str(factors).encode()).hexdigest()
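This pattern (repeated in the other compute_hash() methods below) keys the compile cache on a digest of the relevant factors, so the key is stable across runs and changes only when a graph-affecting field changes. A self-contained sketch of the idea, with example factor values:

# hashing a factor list: same factors -> same key, changed factor -> new key
import hashlib
from typing import Any, List

def hash_factors(factors: List[Any]) -> str:
    return hashlib.sha256(str(factors).encode()).hexdigest()

a = hash_factors(["facebook/opt-125m", "float16", None])
b = hash_factors(["facebook/opt-125m", "float16", None])
c = hash_factors(["facebook/opt-125m", "bfloat16", None])
assert a == b   # identical factors produce the same cache key
assert a != c   # a dtype change produces a different cache key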
def __init__(self,
model: str,
task: Union[TaskOption, Literal["draft"]],
@@ -203,6 +228,8 @@ class ModelConfig:
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
if hf_overrides is None:
hf_overrides = {}
@@ -832,6 +859,24 @@ class CacheConfig:
cpu_offload_gb: Size of the CPU offload buffer in GiB.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.cache_dtype)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __init__(
self,
block_size: int,
@@ -928,6 +973,24 @@ class TokenizerPoolConfig:
pool_type: Union[str, Type["BaseTokenizerGroup"]]
extra_config: dict
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
@@ -1010,6 +1073,24 @@ class LoadConfig:
default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
@@ -1073,6 +1154,19 @@ class ParallelConfig:
rank: int = 0
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
@@ -1209,6 +1303,24 @@ class SchedulerConfig:
chunked_prefill_enabled: bool = field(init=False)
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
@@ -1286,6 +1398,25 @@ class DeviceConfig:
device: Optional[torch.device]
device_type: str
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
@@ -1313,6 +1444,24 @@ class SpeculativeConfig:
decoding with top-1 proposals.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# spec decode does not use `torch.compile` yet.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
@@ -1753,6 +1902,24 @@ class LoRAConfig:
long_lora_scaling_factors: Optional[Tuple[float]] = None
bias_enabled: bool = False
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# LoRA is not compatible with `torch.compile`.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast
# majority of applications.
@@ -1802,6 +1969,24 @@ class PromptAdapterConfig:
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
if self.max_prompt_adapters < 1:
@@ -1830,6 +2015,24 @@ class MultiModalConfig:
for each :class:`~vllm.multimodal.MultiModalPlugin`.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
# TODO: Add configs to init vision tower or not.
@@ -1869,6 +2072,24 @@ class PoolerConfig:
``math-shepherd-mistral-7b-prm`` model.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@staticmethod
def from_json(json_str: str) -> "PoolerConfig":
return PoolerConfig(**json.loads(json_str))
@@ -2103,6 +2324,24 @@ class DecodingConfig:
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend
@@ -2124,6 +2363,24 @@ class ObservabilityConfig:
# If set, collects the model execute time for the request.
collect_model_execute_time: bool = False
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
if not is_otel_available() and self.otlp_traces_endpoint is not None:
raise ValueError(
@@ -2165,6 +2422,24 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
"""Parse the CLI value for the kv cache transfer config."""
@@ -2234,6 +2509,9 @@ class CompilationConfig(BaseModel):
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- cache_dir: the directory to store the compiled graph, to
accelerate Inductor compilation. By default, it will use
model-related information to generate a cache directory.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2302,12 +2580,10 @@ class CompilationConfig(BaseModel):
""" # noqa
level: int = 0
debug_dump_path: str = ""
cache_dir: str = ""
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
- splitting_ops: List[str] = Field(default_factory=lambda: [
-     "vllm.unified_attention",
-     "vllm.unified_attention_with_output",
- ])
+ splitting_ops: List[str] = Field(default=None)  # type: ignore
use_inductor: bool = True
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
@@ -2371,12 +2647,37 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
compilation_time: float = PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache: Any = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.level)
factors.append(self.backend)
factors.append(self.custom_ops)
factors.append(self.splitting_ops)
factors.append(self.use_inductor)
factors.append(self.inductor_compile_config)
factors.append(self.inductor_passes)
factors.append(self.pass_config.uuid())
return hashlib.sha256(str(factors).encode()).hexdigest()
def __repr__(self) -> str:
exclude = {
"static_forward_context",
@@ -2405,6 +2706,27 @@ class CompilationConfig(BaseModel):
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
if self.splitting_ops is None:
if envs.VLLM_USE_V1:
# v1 must split the graph on attention ops
# for piecewise cudagraph
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
else:
# v0 can use full graph compilation without splitting;
# splitting is optional.
# Right now we still need it: the kv cache shape
# will be included in the graph if we don't split
# the graph.
# TODO: hide kv cache in static forward context
# so that inductor does not see it.
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
@@ -2444,6 +2766,30 @@ class CompilationConfig(BaseModel):
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if not self.cache_dir:
# no cache dir provided, generate one based on the known factors
# that affect the compilation. If none of the factors change,
# the cache dir stays the same, so we can reuse the compiled
# graph.
hash_key = vllm_config.compute_hash()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
from vllm.compilation.backends import InductorHashCache
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
self.cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info(
"Using cache directory: %s for vLLM's torch.compile",
self.cache_dir)
from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)
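For reference, a sketch of the on-disk layout this produces; the hash value is illustrative and the default VLLM_CACHE_ROOT is typically ~/.cache/vllm:

~/.cache/vllm/torch_compile_cache/
    1a2b3c4d5e/                      # vllm_config.compute_hash(), placeholder value
        rank_0/                      # one sub-directory per parallel rank
            inductor_hash_cache.py   # written by InductorHashCache.save_to_file()
            inductor_cache/          # TORCHINDUCTOR_CACHE_DIR
            triton_cache/            # TRITON_CACHE_DIR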
@@ -2520,6 +2866,67 @@ class VllmConfig:
init=True) # type: ignore
instance_id: str = ""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
# summarize vllm config
vllm_factors: List[Any] = []
from vllm import __version__
vllm_factors.append(__version__)
if self.model_config:
vllm_factors.append(self.model_config.compute_hash())
if self.cache_config:
vllm_factors.append(self.cache_config.compute_hash())
if self.parallel_config:
vllm_factors.append(self.parallel_config.compute_hash())
if self.scheduler_config:
vllm_factors.append(self.scheduler_config.compute_hash())
if self.device_config:
vllm_factors.append(self.device_config.compute_hash())
if self.load_config:
vllm_factors.append(self.load_config.compute_hash())
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
if self.speculative_config:
vllm_factors.append(self.speculative_config.compute_hash())
if self.decoding_config:
vllm_factors.append(self.decoding_config.compute_hash())
if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash())
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
vllm_factors.append(self.compilation_config.compute_hash())
if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash())
factors.append(vllm_factors)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
return hash_str
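A small sketch of how the top-level key is formed: the per-config digests computed above are nested into one factor list, hashed with md5, and truncated to 10 hex characters to keep the cache directory name short. The values below are placeholders.

# combining sub-hashes into the short top-level cache key (illustrative)
import hashlib
factors = [
    {"system": "..."},                      # CacheBase.get_system() summary
    "torch-key-placeholder",                # torch_key()
    ["0.x.y", "model-hash", "cache-hash"],  # vLLM version + per-config hashes
]
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
assert len(hash_str) == 10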
def pad_for_cudagraph(self, batch_size: int) -> int:
# if batch_size > self.compilation_config.max_capture_size,
# it should raise an IndexError.

vllm/envs.py

@@ -71,6 +71,7 @@ if TYPE_CHECKING:
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
def get_default_cache_root():
@@ -463,6 +464,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
"VLLM_LOG_BATCHSIZE_INTERVAL":
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
"VLLM_DISABLE_COMPILE_CACHE":
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
}
# end-env-vars-definition
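Because the flag is read from the environment when the backend is initialized, the cache can be disabled either by exporting VLLM_DISABLE_COMPILE_CACHE=1 in the shell or by setting it in the process before the engine is constructed; a minimal sketch:

# disable vLLM's torch.compile cache for this process (illustrative)
import os
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"  # equivalent to exporting it in the shell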