[torch.compile] fast inductor (#11108)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Authored by youkaichao, committed via GitHub on 2024-12-16 16:15:22 -08:00
Commit 88a412ed3d (parent c301616ed2)
3 changed files with 624 additions and 7 deletions


@@ -3,6 +3,7 @@ import copy
import enum
import hashlib
import json
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
@@ -162,6 +163,30 @@ class ModelConfig:
which allows no processors.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.model)
factors.append(self.dtype)
factors.append(self.quantization)
factors.append(self.quantization_param_path)
factors.append(self.revision)
factors.append(self.code_revision)
factors.append(self.trust_remote_code)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
return hashlib.sha256(str(factors).encode()).hexdigest()
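
Every compute_hash() method in this diff follows the same factor-list pattern: collect the graph-affecting fields into a list, stringify it, and digest it. A minimal standalone sketch (the model name and dtype values are hypothetical, not taken from this diff):

    import hashlib
    from typing import Any, List

    factors: List[Any] = []
    factors.append("some-org/some-model")   # hypothetical model name
    factors.append("bfloat16")              # hypothetical dtype
    factors.append(None)                    # e.g. quantization unset

    # str() of a list of simple values is deterministic, so identical
    # factor lists always map to the same digest.
    print(hashlib.sha256(str(factors).encode()).hexdigest())

Configs that cannot change the compiled graph (several appear below) hash an empty factor list, which still yields a stable digest.
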
def __init__(self,
model: str,
task: Union[TaskOption, Literal["draft"]],
@@ -203,6 +228,8 @@ class ModelConfig:
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
if hf_overrides is None:
hf_overrides = {}
@@ -832,6 +859,24 @@ class CacheConfig:
cpu_offload_gb: Size of the CPU offload buffer in GiB.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.cache_dtype)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __init__(
self,
block_size: int,
@@ -928,6 +973,24 @@ class TokenizerPoolConfig:
pool_type: Union[str, Type["BaseTokenizerGroup"]]
extra_config: dict
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
@@ -1010,6 +1073,24 @@ class LoadConfig:
default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
@@ -1073,6 +1154,19 @@ class ParallelConfig:
rank: int = 0
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
@@ -1209,6 +1303,24 @@ class SchedulerConfig:
chunked_prefill_enabled: bool = field(init=False)
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
@@ -1286,6 +1398,25 @@ class DeviceConfig:
device: Optional[torch.device]
device_type: str
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
@@ -1313,6 +1444,24 @@ class SpeculativeConfig:
decoding with top-1 proposals.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# spec decode does not use `torch.compile` yet.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
@@ -1753,6 +1902,24 @@ class LoRAConfig:
long_lora_scaling_factors: Optional[Tuple[float]] = None
bias_enabled: bool = False
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# LoRA is not compatible with `torch.compile`.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast
# majority of applications.
@@ -1802,6 +1969,24 @@ class PromptAdapterConfig:
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
if self.max_prompt_adapters < 1:
@@ -1830,6 +2015,24 @@ class MultiModalConfig:
for each :class:`~vllm.multimodal.MultiModalPlugin`.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
# TODO: Add configs to init vision tower or not.
@@ -1869,6 +2072,24 @@ class PoolerConfig:
``math-shepherd-mistral-7b-prm`` model.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@staticmethod
def from_json(json_str: str) -> "PoolerConfig":
return PoolerConfig(**json.loads(json_str))
@@ -2103,6 +2324,24 @@ class DecodingConfig:
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend
@@ -2124,6 +2363,24 @@ class ObservabilityConfig:
# If set, collects the model execute time for the request.
collect_model_execute_time: bool = False
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
def __post_init__(self):
if not is_otel_available() and self.otlp_traces_endpoint is not None:
raise ValueError(
@@ -2165,6 +2422,24 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
"""Parse the CLI value for the kv cache transfer config."""
@@ -2234,6 +2509,9 @@ class CompilationConfig(BaseModel):
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- cache_dir: the directory to store the compiled graph, to
accelerate Inductor compilation. By default, it will use
model-related information to generate a cache directory.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2302,12 +2580,10 @@ class CompilationConfig(BaseModel):
""" # noqa
level: int = 0
debug_dump_path: str = ""
cache_dir: str = ""
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
])
splitting_ops: List[str] = Field(default=None) # type: ignore
use_inductor: bool = True
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
@@ -2371,12 +2647,37 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
compilation_time: float = PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache: Any = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors.append(self.level)
factors.append(self.backend)
factors.append(self.custom_ops)
factors.append(self.splitting_ops)
factors.append(self.use_inductor)
factors.append(self.inductor_compile_config)
factors.append(self.inductor_passes)
factors.append(self.pass_config.uuid())
return hashlib.sha256(str(factors).encode()).hexdigest()
def __repr__(self) -> str:
exclude = {
"static_forward_context",
@@ -2405,6 +2706,27 @@ class CompilationConfig(BaseModel):
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
if self.splitting_ops is None:
if envs.VLLM_USE_V1:
# v1 must split the graph on attention ops
# for piecewise cudagraph
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
else:
# v0 can use full graph compilation without splitting;
# splitting is optional. right now we still need it:
# the kv cache shape would otherwise be baked into
# the compiled graph.
# TODO: hide kv cache in static forward context
# so that inductor does not see it.
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
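
Reading the check above: the attention split points are only filled in when splitting_ops is None, so an explicit value passes through unchanged. A hedged sketch, assuming vLLM is installed:

    from vllm.config import CompilationConfig

    piecewise = CompilationConfig(level=3)  # gets the attention split points
    full_graph = CompilationConfig(level=3, splitting_ops=[])

    # The explicit empty list skips the defaulting and requests full-graph
    # compilation; per the v0 comment above, the kv cache shape then ends
    # up in the compiled graph, so the attention splits remain the
    # practical default for now.
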
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
@@ -2444,6 +2766,30 @@ class CompilationConfig(BaseModel):
# TODO: pass the user-specified backend to piecewise compilation
# and merge it with the use_inductor config
assert self.level == CompilationLevel.PIECEWISE
if not self.cache_dir:
# no cache dir provided; generate one based on the known factors
# that affect the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
hash_key = vllm_config.compute_hash()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
from vllm.compilation.backends import InductorHashCache
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
self.cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info(
"Using cache directory: %s for vLLM's torch.compile",
self.cache_dir)
from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)
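
The resulting on-disk layout can be sketched with plain os.path calls; VLLM_CACHE_ROOT is an env-backed default (assumed here to be ~/.cache/vllm), and the short hash_key comes from VllmConfig.compute_hash() below. Values are illustrative:

    import os

    cache_root = os.path.expanduser("~/.cache/vllm")  # assumed default root
    hash_key = "0123456789"                           # example digest
    rank = 0
    print(os.path.join(cache_root, "torch_compile_cache", hash_key,
                       f"rank_{rank}"))
    # e.g. /home/user/.cache/vllm/torch_compile_cache/0123456789/rank_0

Because the rank is part of the path, each worker in a tensor- or pipeline-parallel run gets its own compiled-graph directory under the shared hash key.
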
@@ -2520,6 +2866,67 @@ class VllmConfig:
init=True) # type: ignore
instance_id: str = ""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
# summarize vllm config
vllm_factors: List[Any] = []
from vllm import __version__
vllm_factors.append(__version__)
if self.model_config:
vllm_factors.append(self.model_config.compute_hash())
if self.cache_config:
vllm_factors.append(self.cache_config.compute_hash())
if self.parallel_config:
vllm_factors.append(self.parallel_config.compute_hash())
if self.scheduler_config:
vllm_factors.append(self.scheduler_config.compute_hash())
if self.device_config:
vllm_factors.append(self.device_config.compute_hash())
if self.load_config:
vllm_factors.append(self.load_config.compute_hash())
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
if self.speculative_config:
vllm_factors.append(self.speculative_config.compute_hash())
if self.decoding_config:
vllm_factors.append(self.decoding_config.compute_hash())
if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash())
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
vllm_factors.append(self.compilation_config.compute_hash())
if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash())
factors.append(vllm_factors)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
return hash_str
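
Structurally, the top-level key mixes a system summary, a PyTorch build key, and the per-config sub-hashes, then truncates to a short filesystem-friendly name. A standalone sketch with made-up stand-in values:

    import hashlib

    factors = [
        {"device": "cuda-12.1"},        # stand-in for CacheBase.get_system()
        "torch-build-key",              # stand-in for torch_key()
        ["0.6.5", "model-cfg-sha256"],  # vLLM version + per-config hashes
    ]
    print(hashlib.md5(str(factors).encode()).hexdigest()[:10])

The md5 here presumably favors a short, stable directory name over cryptographic strength; any change in system, torch build, or graph-affecting config yields a different cache directory.
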
def pad_for_cudagraph(self, batch_size: int) -> int:
# if batch_size > self.compilation_config.max_capture_size,
# it should raise an IndexError.