Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
513
vllm/config.py
513
vllm/config.py
@@ -11,8 +11,8 @@ import textwrap
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
replace)
|
||||
from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
|
||||
is_dataclass, replace)
|
||||
from functools import cached_property
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
@@ -20,7 +20,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
|
||||
Protocol, TypeVar, Union, cast, get_args, get_origin)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import deprecated
|
||||
@@ -57,7 +56,7 @@ if TYPE_CHECKING:
|
||||
|
||||
ConfigType = type[DataclassInstance]
|
||||
else:
|
||||
QuantizationConfig = None
|
||||
QuantizationConfig = Any
|
||||
ConfigType = type
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -169,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT:
|
||||
"""
|
||||
A decorator that ensures all fields in a dataclass have default values
|
||||
and that each field has a docstring.
|
||||
|
||||
If a `ConfigT` is used as a CLI argument itself, the default value provided
|
||||
by `get_kwargs` will be the result parsing a JSON string as the kwargs
|
||||
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
|
||||
requires custom construction from CLI (i.e. `CompilationConfig`), it can
|
||||
have a `from_cli` method, which will be called instead.
|
||||
"""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The decorated class must be a dataclass.")
|
||||
@@ -202,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
cls_fields = {f.name: f for f in fields(cls)}
|
||||
if name not in cls_fields:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
||||
named_field: Field = cls_fields.get(name)
|
||||
named_field: Field = cls_fields[name]
|
||||
if (default_factory := named_field.default_factory) is not MISSING:
|
||||
return field(default_factory=default_factory)
|
||||
if (default := named_field.default) is not MISSING:
|
||||
@@ -211,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return next(f for f in fields(cls) if f.name == name).init
|
||||
|
||||
|
||||
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
|
||||
@@ -2007,13 +2016,13 @@ class SchedulerConfig:
|
||||
def __post_init__(self) -> None:
|
||||
if self.max_model_len is None:
|
||||
self.max_model_len = 8192
|
||||
logger.warning(
|
||||
logger.warning_once(
|
||||
"max_model_len was is not set. Defaulting to arbitrary value "
|
||||
"of %d.", self.max_model_len)
|
||||
|
||||
if self.max_num_seqs is None:
|
||||
self.max_num_seqs = 128
|
||||
logger.warning(
|
||||
logger.warning_once(
|
||||
"max_num_seqs was is not set. Defaulting to arbitrary value "
|
||||
"of %d.", self.max_num_seqs)
|
||||
|
||||
@@ -2840,8 +2849,8 @@ class PromptAdapterConfig:
|
||||
class MultiModalConfig:
|
||||
"""Controls the behavior of multimodal models."""
|
||||
|
||||
limit_per_prompt: dict[str, int] = get_field(ModelConfig,
|
||||
"limit_mm_per_prompt")
|
||||
limit_per_prompt: dict[str, int] = \
|
||||
cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt"))
|
||||
"""
|
||||
The maximum number of input items allowed per prompt for each modality.
|
||||
Defaults to 1 (V0) or 999 (V1) for each modality.
|
||||
@@ -3415,41 +3424,49 @@ class ObservabilityConfig:
|
||||
self.collect_detailed_traces[0].split(","))
|
||||
|
||||
|
||||
class KVTransferConfig(BaseModel):
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
# The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
kv_connector: Optional[str] = None
|
||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
"""
|
||||
|
||||
# The device used by kv connector to buffer the KV cache.
|
||||
# Currently only support 'cuda'.
|
||||
kv_buffer_device: Optional[str] = "cuda"
|
||||
"""The device used by kv connector to buffer the KV cache.
|
||||
Currently only support 'cuda'."""
|
||||
|
||||
# The buffer size for TorchDistributedConnector. Measured in number of
|
||||
# bytes. Recommended value: 1e9 (about 1GB).
|
||||
kv_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
# are 'kv_producer', 'kv_consumer', and 'both'.
|
||||
kv_role: Optional[str] = None
|
||||
kv_role: Optional[KVRole] = None
|
||||
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
are 'kv_producer', 'kv_consumer', and 'both'."""
|
||||
|
||||
# The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
# 0 for prefill instance, 1 for decode instance.
|
||||
# Currently only 1P1D is supported.
|
||||
kv_rank: Optional[int] = None
|
||||
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
0 for prefill instance, 1 for decode instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
# The number of parallel instances for KV cache transfer. For
|
||||
# PyNcclConnector, this should be 2.
|
||||
kv_parallel_size: int = 1
|
||||
"""The number of parallel instances for KV cache transfer. For
|
||||
PyNcclConnector, this should be 2."""
|
||||
|
||||
# The KV connector ip, used to build distributed connection
|
||||
kv_ip: str = "127.0.0.1"
|
||||
"""The KV connector ip, used to build distributed connection."""
|
||||
|
||||
# The KV connector port, used to build distributed connection
|
||||
kv_port: int = 14579
|
||||
"""The KV connector port, used to build distributed connection."""
|
||||
|
||||
# any extra config that the connector may need
|
||||
kv_connector_extra_config: dict[str, Any] = {}
|
||||
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@@ -3470,46 +3487,37 @@ class KVTransferConfig(BaseModel):
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
|
||||
"""Parse the CLI value for the kv cache transfer config."""
|
||||
return KVTransferConfig.model_validate_json(cli_value)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in [
|
||||
"kv_producer", "kv_consumer", "kv_both"
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are `kv_producer`, `kv_consumer`, "
|
||||
f"and `kv_both`")
|
||||
def __post_init__(self) -> None:
|
||||
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are {get_args(KVRole)}")
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||
"is set, supported roles are `kv_producer`, "
|
||||
"`kv_consumer`, and `kv_both`")
|
||||
f"is set, supported roles are {get_args(KVRole)}")
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
|
||||
self.kv_role in get_args(KVRole)
|
||||
|
||||
@property
|
||||
def is_kv_producer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_producer", "kv_both"]
|
||||
self.kv_role in get_args(KVProducer)
|
||||
|
||||
@property
|
||||
def is_kv_consumer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_consumer", "kv_both"]
|
||||
self.kv_role in get_args(KVConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
|
||||
return self.kv_connector_extra_config.get(key, default)
|
||||
|
||||
|
||||
class KVEventsConfig(BaseModel):
|
||||
@config
|
||||
@dataclass
|
||||
class KVEventsConfig:
|
||||
"""Configuration for KV event publishing."""
|
||||
|
||||
enable_kv_cache_events: bool = False
|
||||
@@ -3548,11 +3556,6 @@ class KVEventsConfig(BaseModel):
|
||||
this topic to receive events.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str) -> "KVEventsConfig":
|
||||
"""Parse the CLI value for the event publisher config."""
|
||||
return KVEventsConfig.model_validate_json(cli_value)
|
||||
|
||||
|
||||
class CompilationLevel:
|
||||
# constants for the levels of the compilation process
|
||||
@@ -3562,80 +3565,72 @@ class CompilationLevel:
|
||||
PIECEWISE = 3
|
||||
|
||||
|
||||
class CompilationConfig(BaseModel):
|
||||
"""
|
||||
Configuration for compilation.
|
||||
It has three parts:
|
||||
@config
|
||||
@dataclass
|
||||
class PassConfig:
|
||||
"""Configuration for custom Inductor passes.
|
||||
|
||||
This is separate from general `CompilationConfig` so that inductor passes
|
||||
don't all have access to full configuration - that would create a cycle as
|
||||
the `PassManager` is set as a property of config."""
|
||||
|
||||
dump_graph_stages: list[str] = field(default_factory=list)
|
||||
"""List of stages for which we want to dump the graph. Each pass defines
|
||||
its own stages (before, after, maybe in-between)."""
|
||||
dump_graph_dir: Path = Path(".")
|
||||
"""Directory to dump the graphs."""
|
||||
# TODO(luka) better pass enabling system.
|
||||
enable_fusion: bool = True
|
||||
"""Whether to enable the custom fusion pass."""
|
||||
enable_noop: bool = True
|
||||
"""Whether to enable the custom no-op elimination pass."""
|
||||
enable_sequence_parallelism: bool = False
|
||||
"""Whether to enable sequence parallelism."""
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
Produces a hash unique to the pass configuration.
|
||||
Any new fields that affect compilation should be added to the hash.
|
||||
Do not include dump_graph_* in the hash - they don't affect
|
||||
compilation.
|
||||
"""
|
||||
include = {
|
||||
"enable_fusion", "enable_noop", "enable_sequence_parallelism"
|
||||
}
|
||||
dict_ = {k: v for k, v in asdict(self).items() if k in include}
|
||||
return InductorPass.hash_dict(dict_)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.enable_noop and self.enable_fusion:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm + quant (fp8) fusion might not work")
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class CompilationConfig:
|
||||
"""Configuration for compilation. It has three parts:
|
||||
|
||||
- Top-level Compilation control:
|
||||
- level: the level of compilation.
|
||||
- 0: no compilation.
|
||||
- 1: dynamo as is.
|
||||
- 2: dynamo once.
|
||||
- 3: piecewise compilation.
|
||||
- debug_dump_path: the path to dump the debug information.
|
||||
- cache_dir: the directory to store the compiled graph, to
|
||||
accelerate Inductor compilation. By default, it will use
|
||||
model-related information to generate a cache directory.
|
||||
- backend: the backend for compilation. It needs to be a string.
|
||||
- "" (empty string): use the default backend.
|
||||
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
||||
- "full.module.name": a qualified name which can be used to import the backend function.
|
||||
We use string to avoid serialization issues when using compilation in a distributed setting.
|
||||
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
|
||||
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
|
||||
- custom_ops: fine-grained control over which custom ops to enable/disable.
|
||||
Use 'all' to enable all, 'none' to disable all.
|
||||
Also specify a list of custom op names to enable (prefixed with a '+'),
|
||||
or disable (prefixed with a '-').
|
||||
Examples:
|
||||
- 'all,-op1' to enable all except op1
|
||||
- 'none,+op1,+op2' to enable only op1 and op2
|
||||
By default, all custom ops are enabled when running without Inductor
|
||||
and disabled when running with Inductor (compile_level >= Inductor).
|
||||
- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
|
||||
- {attr}`level`
|
||||
- {attr}`debug_dump_path`
|
||||
- {attr}`cache_dir`
|
||||
- {attr}`backend`
|
||||
- {attr}`custom_ops`
|
||||
- {attr}`splitting_ops`
|
||||
- CudaGraph capture:
|
||||
- use_cudagraph: whether to use cudagraph inside compilation.
|
||||
- False: cudagraph inside compilation is not used.
|
||||
- True: cudagraph inside compilation is used. It requires
|
||||
that all input buffers have fixed addresses, and all
|
||||
splitting ops write their outputs to input buffers.
|
||||
Note that this is orthogonal to the cudagraph capture logic
|
||||
outside of compilation.
|
||||
TODO: move outside cudagraph logic into compilation.
|
||||
torch.compile will handle cudagraph capture logic in the future.
|
||||
- cudagraph_capture_sizes: sizes to capture cudagraph.
|
||||
- None (default): capture sizes are inferred from vllm config.
|
||||
- list[int]: capture sizes are specified as given.
|
||||
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
|
||||
It means the first several runs will be treated as warmup runs.
|
||||
Only after that, the execution will be recorded, and the recorded
|
||||
cudagraph will be used for subsequent runs.
|
||||
- cudagraph_copy_inputs: whether to copy input tensors for
|
||||
cudagraph. If the caller can guarantee that the same input buffers
|
||||
are always used, it can set this to False. Otherwise, it should
|
||||
set this to True, and the compiler will copy the input to an
|
||||
internally managed buffer. Default is False.
|
||||
- full_cuda_graph: whether to use a full cuda graph for the entire forward
|
||||
pass rather than splitting certain operations such as attention into subgraphs.
|
||||
Thus this flag cannot be used together with splitting_ops. This may provide
|
||||
performance benefits for smaller models.
|
||||
- {attr}`use_cudagraph`
|
||||
- {attr}`cudagraph_capture_sizes`
|
||||
- {attr}`cudagraph_num_of_warmups`
|
||||
- {attr}`cudagraph_copy_inputs`
|
||||
- {attr}`full_cuda_graph`
|
||||
- Inductor compilation:
|
||||
- use_inductor: whether to use inductor compilation.
|
||||
- False: inductor compilation is not used. graph runs in eager.
|
||||
- True: inductor compilation is used. one graph for symbolic shape
|
||||
is compiled. In addition, compile for compile_sizes,
|
||||
using configurations in inductor_compile_config.
|
||||
- compile_sizes: sizes to compile for inductor. In addition
|
||||
to integers, it also supports "cudagraph_capture_sizes" to
|
||||
specify the sizes for cudagraph capture.
|
||||
- inductor_compile_config: additional configurations for inductor.
|
||||
- None: use default configurations.
|
||||
- inductor_passes: additional passes for inductor. It is a dictionary
|
||||
from pass name to pass function qualified name. We use function
|
||||
name because the config uses json format. If we pass the config
|
||||
from Python, functions can also be passed directly via Python object
|
||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
|
||||
- custom inductor passes: see PassConfig for more details
|
||||
- {attr}`use_inductor`
|
||||
- {attr}`compile_sizes`
|
||||
- {attr}`inductor_compile_config`
|
||||
- {attr}`inductor_passes`
|
||||
- custom inductor passes
|
||||
|
||||
Why we have different sizes for cudagraph and inductor:
|
||||
- cudagraph: a cudagraph captured for a specific size can only be used
|
||||
@@ -3646,83 +3641,135 @@ class CompilationConfig(BaseModel):
|
||||
static shapes. However, we find the general shape compilation is
|
||||
sufficient for most cases. It might be beneficial to compile for
|
||||
certain small batchsizes, where inductor is good at optimizing.
|
||||
""" # noqa
|
||||
"""
|
||||
# Top-level Compilation control
|
||||
level: int = 0
|
||||
"""The level of compilation:
|
||||
|
||||
- 0: no compilation.
|
||||
- 1: dynamo as is.
|
||||
- 2: dynamo once.
|
||||
- 3: piecewise compilation."""
|
||||
debug_dump_path: str = ""
|
||||
"""The path to dump the debug information."""
|
||||
cache_dir: str = ""
|
||||
"""The directory to store the compiled graph, to accelerate Inductor
|
||||
compilation. By default, it will use model-related information to generate
|
||||
a cache directory."""
|
||||
backend: str = ""
|
||||
custom_ops: list[str] = Field(default_factory=list)
|
||||
splitting_ops: list[str] = Field(default=None) # type: ignore
|
||||
"""The backend for compilation. It needs to be a string:
|
||||
|
||||
- "" (empty string): use the default backend.
|
||||
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
||||
- "full.module.name": a qualified name which can be used to import the
|
||||
|
||||
backend function.
|
||||
We use string to avoid serialization issues when using compilation in a
|
||||
distributed setting. When the compilation level is 1 or 2, the backend is
|
||||
used for the compilation directly (it sees the whole graph). When the
|
||||
compilation level is 3, the backend is used for the piecewise compilation
|
||||
(it sees a part of the graph)."""
|
||||
custom_ops: list[str] = field(default_factory=list)
|
||||
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
|
||||
to enable all, 'none' to disable all. Also specify a list of custom op
|
||||
names to enable (prefixed with a '+'), or disable (prefixed with a '-').
|
||||
Examples:
|
||||
|
||||
- 'all,-op1' to enable all except op1
|
||||
- 'none,+op1,+op2' to enable only op1 and op2
|
||||
|
||||
By default, all custom ops are enabled when running without Inductor and
|
||||
disabled when running with Inductor (compile_level >= Inductor)."""
|
||||
splitting_ops: list[str] = field(default_factory=list)
|
||||
"""A list of ops to split the full graph into subgraphs, used in piecewise
|
||||
compilation."""
|
||||
|
||||
# Inductor capture
|
||||
use_inductor: bool = True
|
||||
compile_sizes: Optional[list[Union[int, str]]] = Field(default=None)
|
||||
inductor_compile_config: dict = Field(default_factory=dict)
|
||||
inductor_passes: dict[str, str] = Field(default_factory=dict)
|
||||
"""Whether to use inductor compilation:
|
||||
|
||||
- False: inductor compilation is not used. graph runs in eager.
|
||||
- True: inductor compilation is used. one graph for symbolic shape
|
||||
is compiled. In addition, compile for compile_sizes,
|
||||
using configurations in inductor_compile_config."""
|
||||
compile_sizes: Optional[list[Union[int, str]]] = None
|
||||
"""Sizes to compile for inductor. In addition
|
||||
to integers, it also supports "cudagraph_capture_sizes" to
|
||||
specify the sizes for cudagraph capture."""
|
||||
inductor_compile_config: dict = field(default_factory=dict)
|
||||
"""Additional configurations for inductor.
|
||||
- None: use default configurations."""
|
||||
inductor_passes: dict[str, str] = field(default_factory=dict)
|
||||
"""Additional passes for inductor. It is a dictionary
|
||||
from pass name to pass function qualified name. We use function
|
||||
name because the config uses JSON format. If we pass the config
|
||||
from Python, functions can also be passed directly via Python object
|
||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
|
||||
|
||||
# CudaGraph compilation
|
||||
use_cudagraph: bool = False
|
||||
"""Whether to use cudagraph inside compilation.
|
||||
- False: cudagraph inside compilation is not used.
|
||||
- True: cudagraph inside compilation is used. It requires
|
||||
that all input buffers have fixed addresses, and all
|
||||
splitting ops write their outputs to input buffers.
|
||||
Note that this is orthogonal to the cudagraph capture logic
|
||||
outside of compilation.
|
||||
TODO: move outside cudagraph logic into compilation.
|
||||
torch.compile will handle cudagraph capture logic in the future."""
|
||||
cudagraph_num_of_warmups: int = 0
|
||||
"""Number of warmup runs for cudagraph.
|
||||
It means the first several runs will be treated as warmup runs.
|
||||
Only after that, the execution will be recorded, and the recorded
|
||||
cudagraph will be used for subsequent runs."""
|
||||
cudagraph_capture_sizes: Optional[list[int]] = None
|
||||
"""Sizes to capture cudagraph.
|
||||
- None (default): capture sizes are inferred from vllm config.
|
||||
- list[int]: capture sizes are specified as given."""
|
||||
cudagraph_copy_inputs: bool = False
|
||||
"""Whether to copy input tensors for
|
||||
cudagraph. If the caller can guarantee that the same input buffers
|
||||
are always used, it can set this to False. Otherwise, it should
|
||||
set this to True, and the compiler will copy the input to an
|
||||
internally managed buffer. Default is False."""
|
||||
full_cuda_graph: bool = False
|
||||
"""whether to use a full cuda graph for the entire forward pass rather than
|
||||
splitting certain operations such as attention into subgraphs. Thus this
|
||||
flag cannot be used together with splitting_ops. This may provide
|
||||
performance benefits for smaller models."""
|
||||
|
||||
class PassConfig(BaseModel):
|
||||
"""
|
||||
Configuration for custom Inductor passes.
|
||||
This is separate from general CompilationConfig so that inductor passes
|
||||
don't all have access to full configuration - that would create a cycle
|
||||
as the PassManager is set as a property of config.
|
||||
- dump_graph_stages: list of stages for which we want to dump the graph.
|
||||
Each pass defines its own stages (before, after, maybe in-between).
|
||||
- dump_graph_dir: directory to dump the graphs. Default is .
|
||||
- enable_fusion: whether to enable the custom fusion pass.
|
||||
- enable_noop: whether to enable the custom no-op elimination pass.
|
||||
TODO(luka) better pass enabling system.
|
||||
- enable_sequence_parallelism: whether to enable sequence parallelism.
|
||||
"""
|
||||
dump_graph_stages: list[str] = Field(default_factory=list)
|
||||
dump_graph_dir: Path = Field(default=Path("."))
|
||||
enable_fusion: bool = True
|
||||
enable_noop: bool = True
|
||||
enable_sequence_parallelism: bool = False
|
||||
pass_config: PassConfig = field(default_factory=PassConfig)
|
||||
"""Custom inductor passes, see PassConfig for more details"""
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
Produces a hash unique to the pass configuration.
|
||||
Any new fields that affect compilation should be added to the hash.
|
||||
Do not include dump_graph_* in the hash - they don't affect
|
||||
compilation.
|
||||
"""
|
||||
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
|
||||
"enable_sequence_parallelism"})
|
||||
return InductorPass.hash_dict(dict_)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if not self.enable_noop and self.enable_fusion:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm + quant (fp8) fusion might not work")
|
||||
|
||||
pass_config: PassConfig = Field(default_factory=PassConfig)
|
||||
|
||||
# not configurable, computed after init
|
||||
max_capture_size: int = PrivateAttr
|
||||
local_cache_dir: str = PrivateAttr # local cache dir for each rank
|
||||
# optimization:
|
||||
# Intuitively, bs_to_padded_graph_size should be dict[int, int].
|
||||
# since we know all keys are in a range [0, max_capture_size],
|
||||
# we can optimize it to list[int] for better lookup performance.
|
||||
bs_to_padded_graph_size: list[int] = PrivateAttr
|
||||
max_capture_size: int = field(default=None, init=False) # type: ignore
|
||||
"""not configurable, computed after init"""
|
||||
local_cache_dir: str = field(default=None, init=False) # type: ignore
|
||||
"""local cache dir for each rank"""
|
||||
bs_to_padded_graph_size: list[int] = field(
|
||||
default=None, # type: ignore
|
||||
init=False)
|
||||
"""optimization:
|
||||
Intuitively, bs_to_padded_graph_size should be dict[int, int].
|
||||
since we know all keys are in a range [0, max_capture_size],
|
||||
we can optimize it to list[int] for better lookup performance."""
|
||||
|
||||
# keep track of enabled and disabled custom ops
|
||||
enabled_custom_ops: Counter[str] = PrivateAttr
|
||||
disabled_custom_ops: Counter[str] = PrivateAttr
|
||||
traced_files: set[str] = PrivateAttr
|
||||
compilation_time: float = PrivateAttr
|
||||
enabled_custom_ops: Counter[str] = field(default_factory=Counter,
|
||||
init=False)
|
||||
"""custom ops that are enabled"""
|
||||
disabled_custom_ops: Counter[str] = field(default_factory=Counter,
|
||||
init=False)
|
||||
"""custom ops that are disabled"""
|
||||
traced_files: set[str] = field(default_factory=set, init=False)
|
||||
"""files that are traced for compilation"""
|
||||
compilation_time: float = field(default=0.0, init=False)
|
||||
"""time taken for compilation"""
|
||||
|
||||
# Per-model forward context
|
||||
# Map from layer name to layer objects that need to be accessed outside
|
||||
# model code, e.g., Attention, FusedMOE when dp_size>1.
|
||||
static_forward_context: dict[str, Any] = PrivateAttr
|
||||
static_forward_context: dict[str, Any] = field(default_factory=dict,
|
||||
init=False)
|
||||
"""Per-model forward context
|
||||
Map from layer name to layer objects that need to be accessed outside
|
||||
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@@ -3757,7 +3804,17 @@ class CompilationConfig(BaseModel):
|
||||
"pass_config",
|
||||
"traced_files",
|
||||
}
|
||||
return self.model_dump_json(exclude=exclude, exclude_unset=True)
|
||||
include = dict()
|
||||
for k, v in asdict(self).items():
|
||||
if k in exclude:
|
||||
continue
|
||||
f = get_field(CompilationConfig, k)
|
||||
if (d := f.default) is not MISSING and d == v:
|
||||
continue
|
||||
if (df := f.default_factory) is not MISSING and df() == v:
|
||||
continue
|
||||
include[k] = v
|
||||
return json.dumps(include)
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@@ -3766,12 +3823,9 @@ class CompilationConfig(BaseModel):
|
||||
"""Parse the CLI value for the compilation config."""
|
||||
if cli_value in ["0", "1", "2", "3"]:
|
||||
return cls(level=int(cli_value))
|
||||
# do not use `eval`, it is dangerous and can execute arbitrary code
|
||||
dict_value = ast.literal_eval(cli_value)
|
||||
return CompilationConfig.model_validate(dict_value)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
return cls(**json.loads(cli_value))
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
count_none = self.custom_ops.count("none")
|
||||
count_all = self.custom_ops.count("all")
|
||||
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
|
||||
@@ -3789,9 +3843,6 @@ class CompilationConfig(BaseModel):
|
||||
if KEY not in self.inductor_compile_config:
|
||||
self.inductor_compile_config[KEY] = False
|
||||
|
||||
if self.splitting_ops is None:
|
||||
self.splitting_ops = []
|
||||
|
||||
for k, v in self.inductor_passes.items():
|
||||
if not isinstance(v, str):
|
||||
assert callable(v), (
|
||||
@@ -3808,11 +3859,8 @@ class CompilationConfig(BaseModel):
|
||||
self.inductor_compile_config[k] = func if isinstance(
|
||||
func, InductorPass) else CallableInductorPass(func)
|
||||
|
||||
self.enabled_custom_ops = Counter()
|
||||
self.disabled_custom_ops = Counter()
|
||||
self.traced_files = set()
|
||||
self.static_forward_context = {}
|
||||
self.compilation_time = 0.0
|
||||
if isinstance(self.pass_config, dict):
|
||||
self.pass_config = PassConfig(**self.pass_config)
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||
if self.level == CompilationLevel.NO_COMPILATION:
|
||||
@@ -3899,39 +3947,67 @@ class CompilationConfig(BaseModel):
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class VllmConfig:
|
||||
"""Dataclass which contains all vllm-related configuration. This
|
||||
simplifies passing around the distinct configurations in the codebase.
|
||||
"""
|
||||
|
||||
model_config: ModelConfig = field(default=None, init=True) # type: ignore
|
||||
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
|
||||
parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
|
||||
init=True)
|
||||
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
|
||||
init=True)
|
||||
device_config: DeviceConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
load_config: LoadConfig = field(default=None, init=True) # type: ignore
|
||||
model_config: ModelConfig = field(default_factory=ModelConfig)
|
||||
"""Model configuration."""
|
||||
cache_config: CacheConfig = field(default_factory=CacheConfig)
|
||||
"""Cache configuration."""
|
||||
parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
|
||||
"""Parallel configuration."""
|
||||
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
|
||||
"""Scheduler configuration."""
|
||||
device_config: DeviceConfig = field(default_factory=DeviceConfig)
|
||||
"""Device configuration."""
|
||||
load_config: LoadConfig = field(default_factory=LoadConfig)
|
||||
"""Load configuration."""
|
||||
lora_config: Optional[LoRAConfig] = None
|
||||
speculative_config: SpeculativeConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
"""LoRA configuration."""
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
"""Speculative decoding configuration."""
|
||||
decoding_config: Optional[DecodingConfig] = None
|
||||
"""Decoding configuration."""
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
"""Observability configuration."""
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None
|
||||
"""Prompt adapter configuration."""
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
compilation_config: CompilationConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
kv_transfer_config: KVTransferConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
"""Quantization configuration."""
|
||||
compilation_config: CompilationConfig = field(
|
||||
default_factory=CompilationConfig)
|
||||
"""`torch.compile` configuration for the model.
|
||||
|
||||
When it is a number (0, 1, 2, 3), it will be interpreted as the
|
||||
optimization level.
|
||||
|
||||
NOTE: level 0 is the default level without any optimization. level 1 and 2
|
||||
are for internal testing only. level 3 is the recommended level for
|
||||
production.
|
||||
|
||||
Following the convention of traditional compilers, using `-O` without space
|
||||
is also supported. `-O3` is equivalent to `-O 3`.
|
||||
|
||||
You can specify the full compilation config like so:
|
||||
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||
"""
|
||||
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||
"""The configurations for distributed KV cache transfer."""
|
||||
kv_events_config: Optional[KVEventsConfig] = None
|
||||
"""The configurations for event publishing."""
|
||||
# some opaque config, only used to provide additional information
|
||||
# for the hash computation, mainly used for testing, debugging or out of
|
||||
# tree config registration.
|
||||
additional_config: SupportsHash = field(default=None,
|
||||
init=True) # type: ignore
|
||||
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
|
||||
"""Additional config for specified platform. Different platforms may
|
||||
support different configs. Make sure the configs are valid for the platform
|
||||
you are using. Contents must be hashable."""
|
||||
instance_id: str = ""
|
||||
"""The ID of the vLLM instance."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@@ -4012,7 +4088,14 @@ class VllmConfig:
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.additional_config:
|
||||
vllm_factors.append(self.additional_config.compute_hash())
|
||||
if isinstance(additional_config := self.additional_config, dict):
|
||||
additional_config_hash = hashlib.md5(
|
||||
json.dumps(additional_config, sort_keys=True).encode(),
|
||||
usedforsecurity=False,
|
||||
).hexdigest()
|
||||
else:
|
||||
additional_config_hash = additional_config.compute_hash()
|
||||
vllm_factors.append(additional_config_hash)
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
factors.append(vllm_factors)
|
||||
|
||||
Reference in New Issue
Block a user